import asyncio
import logging
import math
import random
import re
import time
import uuid
import xml.etree.ElementTree as ET
from typing import Any, AsyncGenerator, Coroutine, Optional
from uuid import UUID
from xml.etree.ElementTree import Element

from core.base import (
    DocumentChunk,
    GraphExtraction,
    GraphExtractionStatus,
    R2RDocumentProcessingError,
)
from core.base.abstractions import (
    Community,
    Entity,
    GenerationConfig,
    GraphConstructionStatus,
    R2RException,
    Relationship,
    StoreType,
)
from core.base.api.models import GraphResponse

from ..abstractions import R2RProviders
from ..config import R2RConfig
from .base import Service

logger = logging.getLogger()

MIN_VALID_GRAPH_EXTRACTION_RESPONSE_LENGTH = 128


async def _collect_async_results(result_gen: AsyncGenerator) -> list[Any]:
    """Collects all results from an async generator into a list."""
    results = []
    async for res in result_gen:
        results.append(res)
    return results


class GraphService(Service):
    def __init__(
        self,
        config: R2RConfig,
        providers: R2RProviders,
    ):
        super().__init__(
            config,
            providers,
        )

    async def create_entity(
        self,
        name: str,
        description: str,
        parent_id: UUID,
        category: Optional[str] = None,
        metadata: Optional[dict] = None,
    ) -> Entity:
        description_embedding = str(
            await self.providers.embedding.async_get_embedding(description)
        )

        return await self.providers.database.graphs_handler.entities.create(
            name=name,
            parent_id=parent_id,
            store_type=StoreType.GRAPHS,
            category=category,
            description=description,
            description_embedding=description_embedding,
            metadata=metadata,
        )

    async def update_entity(
        self,
        entity_id: UUID,
        name: Optional[str] = None,
        description: Optional[str] = None,
        category: Optional[str] = None,
        metadata: Optional[dict] = None,
    ) -> Entity:
        description_embedding = None
        if description is not None:
            description_embedding = str(
                await self.providers.embedding.async_get_embedding(description)
            )

        return await self.providers.database.graphs_handler.entities.update(
            entity_id=entity_id,
            store_type=StoreType.GRAPHS,
            name=name,
            description=description,
            description_embedding=description_embedding,
            category=category,
            metadata=metadata,
        )

    async def delete_entity(
        self,
        parent_id: UUID,
        entity_id: UUID,
    ):
        return await self.providers.database.graphs_handler.entities.delete(
            parent_id=parent_id,
            entity_ids=[entity_id],
            store_type=StoreType.GRAPHS,
        )

    async def get_entities(
        self,
        parent_id: UUID,
        offset: int,
        limit: int,
        entity_ids: Optional[list[UUID]] = None,
        entity_names: Optional[list[str]] = None,
        include_embeddings: bool = False,
    ):
        return await self.providers.database.graphs_handler.get_entities(
            parent_id=parent_id,
            offset=offset,
            limit=limit,
            entity_ids=entity_ids,
            entity_names=entity_names,
            include_embeddings=include_embeddings,
        )

    async def create_relationship(
        self,
        subject: str,
        subject_id: UUID,
        predicate: str,
        object: str,
        object_id: UUID,
        parent_id: UUID,
        description: str | None = None,
        weight: float | None = 1.0,
        metadata: Optional[dict[str, Any] | str] = None,
    ) -> Relationship:
        description_embedding = None
        if description:
            description_embedding = str(
                await self.providers.embedding.async_get_embedding(description)
            )

        return (
            await self.providers.database.graphs_handler.relationships.create(
                subject=subject,
                subject_id=subject_id,
                predicate=predicate,
                object=object,
                object_id=object_id,
                parent_id=parent_id,
                description=description,
                description_embedding=description_embedding,
                weight=weight,
                metadata=metadata,
                store_type=StoreType.GRAPHS,
            )
        )
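
    # Illustrative usage (not part of the service API): the entity/relationship
    # CRUD methods above are normally awaited from a request handler that already
    # holds a GraphService instance. A hypothetical call might look like:
    #
    #   entity = await graph_service.create_entity(
    #       name="Ada Lovelace",
    #       description="Early computing pioneer",
    #       parent_id=graph_id,
    #   )
    #   await graph_service.create_relationship(
    #       subject="Ada Lovelace",
    #       subject_id=entity.id,
    #       predicate="collaborated_with",
    #       object="Charles Babbage",
    #       object_id=other_entity.id,
    #       parent_id=graph_id,
    #   )
    #
    # Here `graph_service`, `graph_id`, and `other_entity` are assumed to come
    # from the caller's context.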

    async def delete_relationship(
        self,
        parent_id: UUID,
        relationship_id: UUID,
    ):
        return (
            await self.providers.database.graphs_handler.relationships.delete(
                parent_id=parent_id,
                relationship_ids=[relationship_id],
                store_type=StoreType.GRAPHS,
            )
        )

    async def update_relationship(
        self,
        relationship_id: UUID,
        subject: Optional[str] = None,
        subject_id: Optional[UUID] = None,
        predicate: Optional[str] = None,
        object: Optional[str] = None,
        object_id: Optional[UUID] = None,
        description: Optional[str] = None,
        weight: Optional[float] = None,
        metadata: Optional[dict[str, Any] | str] = None,
    ) -> Relationship:
        description_embedding = None
        if description is not None:
            description_embedding = str(
                await self.providers.embedding.async_get_embedding(description)
            )

        return (
            await self.providers.database.graphs_handler.relationships.update(
                relationship_id=relationship_id,
                subject=subject,
                subject_id=subject_id,
                predicate=predicate,
                object=object,
                object_id=object_id,
                description=description,
                description_embedding=description_embedding,
                weight=weight,
                metadata=metadata,
                store_type=StoreType.GRAPHS,
            )
        )

    async def get_relationships(
        self,
        parent_id: UUID,
        offset: int,
        limit: int,
        relationship_ids: Optional[list[UUID]] = None,
        entity_names: Optional[list[str]] = None,
    ):
        return await self.providers.database.graphs_handler.relationships.get(
            parent_id=parent_id,
            store_type=StoreType.GRAPHS,
            offset=offset,
            limit=limit,
            relationship_ids=relationship_ids,
            entity_names=entity_names,
        )

    async def create_community(
        self,
        parent_id: UUID,
        name: str,
        summary: str,
        findings: Optional[list[str]],
        rating: Optional[float],
        rating_explanation: Optional[str],
    ) -> Community:
        description_embedding = str(
            await self.providers.embedding.async_get_embedding(summary)
        )
        return await self.providers.database.graphs_handler.communities.create(
            parent_id=parent_id,
            store_type=StoreType.GRAPHS,
            name=name,
            summary=summary,
            description_embedding=description_embedding,
            findings=findings,
            rating=rating,
            rating_explanation=rating_explanation,
        )

    async def update_community(
        self,
        community_id: UUID,
        name: Optional[str],
        summary: Optional[str],
        findings: Optional[list[str]],
        rating: Optional[float],
        rating_explanation: Optional[str],
    ) -> Community:
        summary_embedding = None
        if summary is not None:
            summary_embedding = str(
                await self.providers.embedding.async_get_embedding(summary)
            )

        return await self.providers.database.graphs_handler.communities.update(
            community_id=community_id,
            store_type=StoreType.GRAPHS,
            name=name,
            summary=summary,
            summary_embedding=summary_embedding,
            findings=findings,
            rating=rating,
            rating_explanation=rating_explanation,
        )

    async def delete_community(
        self,
        parent_id: UUID,
        community_id: UUID,
    ) -> None:
        await self.providers.database.graphs_handler.communities.delete(
            parent_id=parent_id,
            community_id=community_id,
        )

    async def get_communities(
        self,
        parent_id: UUID,
        offset: int,
        limit: int,
        community_ids: Optional[list[UUID]] = None,
        community_names: Optional[list[str]] = None,
        include_embeddings: bool = False,
    ):
        return await self.providers.database.graphs_handler.get_communities(
            parent_id=parent_id,
            offset=offset,
            limit=limit,
            community_ids=community_ids,
            include_embeddings=include_embeddings,
        )

    async def list_graphs(
        self,
        offset: int,
        limit: int,
        graph_ids: Optional[list[UUID]] = None,
        collection_id: Optional[UUID] = None,
    ) -> dict[str, list[GraphResponse] | int]:
        return await self.providers.database.graphs_handler.list_graphs(
            offset=offset,
            limit=limit,
            filter_graph_ids=graph_ids,
            filter_collection_id=collection_id,
        )

    async def update_graph(
        self,
        collection_id: UUID,
        name: Optional[str] = None,
        description: Optional[str] = None,
    ) -> GraphResponse:
        return await self.providers.database.graphs_handler.update(
            collection_id=collection_id,
            name=name,
            description=description,
        )

    async def reset_graph(self, id: UUID) -> bool:
        await self.providers.database.graphs_handler.reset(
            parent_id=id,
        )
        await self.providers.database.documents_handler.set_workflow_status(
            id=id,
            status_type="graph_cluster_status",
            status=GraphConstructionStatus.PENDING,
        )
        return True

    async def get_document_ids_for_create_graph(
        self,
        collection_id: UUID,
        **kwargs,
    ):
        document_status_filter = [
            GraphExtractionStatus.PENDING,
            GraphExtractionStatus.FAILED,
        ]

        return await self.providers.database.documents_handler.get_document_ids_by_status(
            status_type="extraction_status",
            status=[str(ele) for ele in document_status_filter],
            collection_id=collection_id,
        )

    async def graph_search_results_entity_description(
        self,
        document_id: UUID,
        max_description_input_length: int,
        batch_size: int = 256,
        **kwargs,
    ):
        """A new implementation of the old GraphDescriptionPipe logic inline.
        No references to pipe objects. We:

        1) Count how many entities are in the document
        2) Process them in batches of `batch_size`
        3) For each batch, we retrieve the entity map and possibly call LLM
           for missing descriptions
        """
        start_time = time.time()

        logger.info(
            f"GraphService: Running graph_search_results_entity_description for doc={document_id}"
        )

        # Count how many doc-entities exist
        entity_count = (
            await self.providers.database.graphs_handler.get_entity_count(
                document_id=document_id,
                distinct=True,
                entity_table_name="documents_entities",  # or whichever table
            )
        )
        logger.info(
            f"GraphService: Found {entity_count} doc-entities to describe."
        )

        all_results = []
        num_batches = math.ceil(entity_count / batch_size)

        for i in range(num_batches):
            offset = i * batch_size
            limit = batch_size

            logger.info(
                f"GraphService: describing batch {i + 1}/{num_batches}, offset={offset}, limit={limit}"
            )

            # Actually handle describing the entities in the batch
            # We'll collect them into a list via an async generator
            gen = self._describe_entities_in_document_batch(
                document_id=document_id,
                offset=offset,
                limit=limit,
                max_description_input_length=max_description_input_length,
            )
            batch_results = await _collect_async_results(gen)
            all_results.append(batch_results)

        # Mark the doc's extraction status as success
        await self.providers.database.documents_handler.set_workflow_status(
            id=document_id,
            status_type="extraction_status",
            status=GraphExtractionStatus.SUCCESS,
        )
        logger.info(
            f"GraphService: Completed graph_search_results_entity_description for doc {document_id} in {time.time() - start_time:.2f}s."
        )
        return all_results
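
    # A quick sanity check of the batching above (numbers illustrative only):
    # with the default batch_size=256, a document reporting 600 entities is
    # described in math.ceil(600 / 256) == 3 batches, at offsets 0, 256, and
    # 512; the final batch simply yields fewer than `limit` entities.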

    async def _describe_entities_in_document_batch(
        self,
        document_id: UUID,
        offset: int,
        limit: int,
        max_description_input_length: int,
    ) -> AsyncGenerator[str, None]:
        """Core logic that replaces GraphDescriptionPipe._run_logic for a
        particular document/batch.

        Yields entity-names or some textual result as each entity is updated.
        """
        start_time = time.time()
        logger.info(
            f"Started describing doc={document_id}, offset={offset}, limit={limit}"
        )

        # 1) Get the "entity map" from the DB
        entity_map = (
            await self.providers.database.graphs_handler.get_entity_map(
                offset=offset, limit=limit, document_id=document_id
            )
        )
        total_entities = len(entity_map)
        logger.info(
            f"_describe_entities_in_document_batch: got {total_entities} items in entity_map for doc={document_id}."
        )

        # 2) For each entity name in the map, we gather sub-entities and relationships
        tasks: list[Coroutine[Any, Any, str]] = []
        tasks.extend(
            self._process_entity_for_description(
                entities=[
                    entity if isinstance(entity, Entity) else Entity(**entity)
                    for entity in entity_info["entities"]
                ],
                relationships=[
                    rel
                    if isinstance(rel, Relationship)
                    else Relationship(**rel)
                    for rel in entity_info["relationships"]
                ],
                document_id=document_id,
                max_description_input_length=max_description_input_length,
            )
            for entity_name, entity_info in entity_map.items()
        )

        # 3) Wait for all tasks, yield as they complete
        idx = 0
        for coro in asyncio.as_completed(tasks):
            result = await coro
            idx += 1
            if idx % 100 == 0:
                logger.info(
                    f"_describe_entities_in_document_batch: {idx}/{total_entities} described for doc={document_id}"
                )
            yield result

        logger.info(
            f"Finished describing doc={document_id} batch offset={offset} in {time.time() - start_time:.2f}s."
        )

    async def _process_entity_for_description(
        self,
        entities: list[Entity],
        relationships: list[Relationship],
        document_id: UUID,
        max_description_input_length: int,
    ) -> str:
        """Adapted from the old process_entity function in GraphDescriptionPipe.

        If entity has no description, call an LLM to create one, then store it.
        Returns the name of the top entity (or could store more details).
        """

        def truncate_info(info_list: list[str], max_length: int) -> str:
            """Shuffles lines of info to try to keep them distinct, then
            accumulates until hitting max_length."""
            random.shuffle(info_list)
            truncated_info = ""
            current_length = 0
            for info in info_list:
                if current_length + len(info) > max_length:
                    break
                truncated_info += info + "\n"
                current_length += len(info)
            return truncated_info

        # Grab a doc-level summary (optional) to feed into the prompt
        response = await self.providers.database.documents_handler.get_documents_overview(
            offset=0,
            limit=1,
            filter_document_ids=[document_id],
        )
        document_summary = (
            response["results"][0].summary if response["results"] else None
        )

        # Synthesize a minimal "entity info" string + relationship summary
        entity_info = [
            f"{e.name}, {e.description or 'NONE'}" for e in entities
        ]
        relationships_txt = [
            f"{i + 1}: {r.subject}, {r.object}, {r.predicate} - Summary: {r.description or ''}"
            for i, r in enumerate(relationships)
        ]

        # We'll describe only the first entity for simplicity
        # or you could do them all if needed
        main_entity = entities[0]

        if not main_entity.description:
            # We only call LLM if the entity is missing a description
            messages = await self.providers.database.prompts_handler.get_message_payload(
                task_prompt_name=self.providers.database.config.graph_creation_settings.graph_entity_description_prompt,
                task_inputs={
                    "document_summary": document_summary,
                    "entity_info": truncate_info(
                        entity_info, max_description_input_length
                    ),
                    "relationships_txt": truncate_info(
                        relationships_txt, max_description_input_length
                    ),
                },
            )

            # Call the LLM
            gen_config = (
                self.providers.database.config.graph_creation_settings.generation_config
                or GenerationConfig(model=self.config.app.fast_llm)
            )
            llm_resp = await self.providers.llm.aget_completion(
                messages=messages,
                generation_config=gen_config,
            )
            new_description = llm_resp.choices[0].message.content

            if not new_description:
                logger.error(
                    f"No LLM description returned for entity={main_entity.name}"
                )
                return main_entity.name

            # create embedding
            embed = (
                await self.providers.embedding.async_get_embeddings(
                    [new_description]
                )
            )[0]

            # update DB
            main_entity.description = new_description
            main_entity.description_embedding = embed

            # Use a method to upsert entity in `documents_entities` or your table
            await self.providers.database.graphs_handler.add_entities(
                [main_entity],
                table_name="documents_entities",
            )

        return main_entity.name

    async def graph_search_results_clustering(
        self,
        collection_id: UUID,
        generation_config: GenerationConfig,
        leiden_params: dict,
        **kwargs,
    ):
        """
        Replacement for the old GraphClusteringPipe logic:
        1) call perform_graph_clustering on the DB
        2) return the result
        """
        logger.info(
            f"Running inline clustering for collection={collection_id} with params={leiden_params}"
        )
        return await self._perform_graph_clustering(
            collection_id=collection_id,
            generation_config=generation_config,
            leiden_params=leiden_params,
        )

    async def _perform_graph_clustering(
        self,
        collection_id: UUID,
        generation_config: GenerationConfig,
        leiden_params: dict,
    ) -> dict:
        """The actual clustering logic (previously in
        GraphClusteringPipe.cluster_graph_search_results)."""
        num_communities = await self.providers.database.graphs_handler.perform_graph_clustering(
            collection_id=collection_id,
            leiden_params=leiden_params,
        )
        return {"num_communities": num_communities}

    async def graph_search_results_community_summary(
        self,
        offset: int,
        limit: int,
        max_summary_input_length: int,
        generation_config: GenerationConfig,
        collection_id: UUID,
        leiden_params: Optional[dict] = None,
        **kwargs,
    ):
        """Replacement for the old GraphCommunitySummaryPipe logic.

        Summarizes communities after clustering. Returns an async generator
        or you can collect into a list.
        """
        logger.info(
            f"Running inline community summaries for coll={collection_id}, offset={offset}, limit={limit}"
        )
        # We call an internal function that yields summaries
        gen = self._summarize_communities(
            offset=offset,
            limit=limit,
            max_summary_input_length=max_summary_input_length,
            generation_config=generation_config,
            collection_id=collection_id,
            leiden_params=leiden_params or {},
        )
        return await _collect_async_results(gen)
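
    # Sketch of how the two steps above are typically chained (names such as
    # `service`, `collection_id`, and `gen_cfg`, and the leiden_params value,
    # are assumptions about the caller's scope, not part of this module):
    #
    #   stats = await service.graph_search_results_clustering(
    #       collection_id=collection_id,
    #       generation_config=gen_cfg,
    #       leiden_params={},
    #   )
    #   summaries = await service.graph_search_results_community_summary(
    #       offset=0,
    #       limit=stats["num_communities"],
    #       max_summary_input_length=8192,
    #       generation_config=gen_cfg,
    #       collection_id=collection_id,
    #   )
    #
    # Each element of `summaries` is a dict carrying either a "name" or an
    # "error" key, as produced by _process_community_summary below.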
""" start_time = time.time() logger.info( f"Starting community summarization for collection={collection_id}" ) # get all entities & relationships ( all_entities, _, ) = await self.providers.database.graphs_handler.get_entities( parent_id=collection_id, offset=0, limit=-1, include_embeddings=False, ) ( all_relationships, _, ) = await self.providers.database.graphs_handler.get_relationships( parent_id=collection_id, offset=0, limit=-1, include_embeddings=False, ) # We can optionally re-run the clustering to produce fresh community assignments ( _, community_clusters, ) = await self.providers.database.graphs_handler._cluster_and_add_community_info( relationships=all_relationships, leiden_params=leiden_params, collection_id=collection_id, ) # Group clusters clusters: dict[Any, list[str]] = {} for item in community_clusters: cluster_id = item["cluster"] node_name = item["node"] clusters.setdefault(cluster_id, []).append(node_name) # create an async job for each cluster tasks: list[Coroutine[Any, Any, dict]] = [] tasks.extend( self._process_community_summary( community_id=uuid.uuid4(), nodes=nodes, all_entities=all_entities, all_relationships=all_relationships, max_summary_input_length=max_summary_input_length, generation_config=generation_config, collection_id=collection_id, ) for nodes in clusters.values() ) total_jobs = len(tasks) results_returned = 0 total_errors = 0 for coro in asyncio.as_completed(tasks): summary = await coro results_returned += 1 if results_returned % 50 == 0: logger.info( f"Community summaries: {results_returned}/{total_jobs} done in {time.time() - start_time:.2f}s" ) if "error" in summary: total_errors += 1 yield summary if total_errors > 0: logger.warning( f"{total_errors} communities failed summarization out of {total_jobs}" ) async def _process_community_summary( self, community_id: UUID, nodes: list[str], all_entities: list[Entity], all_relationships: list[Relationship], max_summary_input_length: int, generation_config: GenerationConfig, collection_id: UUID, ) -> dict: """ Summarize a single community: gather all relevant entities/relationships, call LLM to generate an XML block, parse it, store the result as a community in DB. """ # (Equivalent to process_community in old code) # fetch the collection description (optional) response = await self.providers.database.collections_handler.get_collections_overview( offset=0, limit=1, filter_collection_ids=[collection_id], ) collection_description = ( response["results"][0].description if response["results"] else None # type: ignore ) # filter out relevant entities / relationships entities = [e for e in all_entities if e.name in nodes] relationships = [ r for r in all_relationships if r.subject in nodes and r.object in nodes ] if not entities and not relationships: return { "community_id": community_id, "error": f"No data in this community (nodes={nodes})", } # Create the big input text for the LLM input_text = await self._community_summary_prompt( entities, relationships, max_summary_input_length, ) # Attempt up to 3 times to parse for attempt in range(3): try: # Build the prompt messages = await self.providers.database.prompts_handler.get_message_payload( task_prompt_name=self.providers.database.config.graph_enrichment_settings.graph_communities_prompt, task_inputs={ "collection_description": collection_description, "input_text": input_text, }, ) llm_resp = await self.providers.llm.aget_completion( messages=messages, generation_config=generation_config, ) llm_text = llm_resp.choices[0].message.content or "" # find ... 

    async def _process_community_summary(
        self,
        community_id: UUID,
        nodes: list[str],
        all_entities: list[Entity],
        all_relationships: list[Relationship],
        max_summary_input_length: int,
        generation_config: GenerationConfig,
        collection_id: UUID,
    ) -> dict:
        """
        Summarize a single community: gather all relevant entities/relationships,
        call LLM to generate an XML block, parse it, store the result as a
        community in DB.
        """
        # (Equivalent to process_community in old code)
        # fetch the collection description (optional)
        response = await self.providers.database.collections_handler.get_collections_overview(
            offset=0,
            limit=1,
            filter_collection_ids=[collection_id],
        )
        collection_description = (
            response["results"][0].description if response["results"] else None  # type: ignore
        )

        # filter out relevant entities / relationships
        entities = [e for e in all_entities if e.name in nodes]
        relationships = [
            r
            for r in all_relationships
            if r.subject in nodes and r.object in nodes
        ]
        if not entities and not relationships:
            return {
                "community_id": community_id,
                "error": f"No data in this community (nodes={nodes})",
            }

        # Create the big input text for the LLM
        input_text = await self._community_summary_prompt(
            entities,
            relationships,
            max_summary_input_length,
        )

        # Attempt up to 3 times to parse
        for attempt in range(3):
            try:
                # Build the prompt
                messages = await self.providers.database.prompts_handler.get_message_payload(
                    task_prompt_name=self.providers.database.config.graph_enrichment_settings.graph_communities_prompt,
                    task_inputs={
                        "collection_description": collection_description,
                        "input_text": input_text,
                    },
                )
                llm_resp = await self.providers.llm.aget_completion(
                    messages=messages,
                    generation_config=generation_config,
                )
                llm_text = llm_resp.choices[0].message.content or ""

                # find the <community>...</community> XML block
                match = re.search(
                    r"<community>.*?</community>", llm_text, re.DOTALL
                )
                if not match:
                    raise ValueError(
                        "No <community> XML block found in LLM response"
                    )

                xml_content = re.sub(
                    r"&(?!amp;|quot;|apos;|lt;|gt;)", "&amp;", match.group(0)
                ).strip()
                root = ET.fromstring(xml_content)

                # extract fields
                name_elem = root.find("name")
                summary_elem = root.find("summary")
                rating_elem = root.find("rating")
                rating_expl_elem = root.find("rating_explanation")
                findings_elem = root.find("findings")

                name = name_elem.text if name_elem is not None else ""
                summary = summary_elem.text if summary_elem is not None else ""
                rating = (
                    float(rating_elem.text)
                    if isinstance(rating_elem, Element) and rating_elem.text
                    else ""
                )
                rating_explanation = (
                    rating_expl_elem.text
                    if rating_expl_elem is not None
                    else None
                )
                findings = (
                    [f.text for f in findings_elem.findall("finding")]
                    if findings_elem is not None
                    else []
                )

                # build embedding
                embed_text = (
                    "Summary:\n"
                    + (summary or "")
                    + "\n\nFindings:\n"
                    + "\n".join(
                        finding for finding in findings if finding is not None
                    )
                )
                embedding = await self.providers.embedding.async_get_embedding(
                    embed_text
                )

                # build Community object
                community = Community(
                    community_id=community_id,
                    collection_id=collection_id,
                    name=name,
                    summary=summary,
                    rating=rating,
                    rating_explanation=rating_explanation,
                    findings=findings,
                    description_embedding=embedding,
                )

                # store it
                await self.providers.database.graphs_handler.add_community(
                    community
                )
                return {
                    "community_id": community_id,
                    "name": name,
                }

            except Exception as e:
                logger.error(
                    f"Error summarizing community {community_id}: {e}"
                )
                if attempt == 2:
                    return {"community_id": community_id, "error": str(e)}
                await asyncio.sleep(1)

        # fallback
        return {"community_id": community_id, "error": "Failed after retries"}

    async def _community_summary_prompt(
        self,
        entities: list[Entity],
        relationships: list[Relationship],
        max_summary_input_length: int,
    ) -> str:
        """Gathers the entity/relationship text, tries not to exceed
        `max_summary_input_length`."""
        # Group them by entity.name
        entity_map: dict[str, dict] = {}
        for e in entities:
            entity_map.setdefault(
                e.name, {"entities": [], "relationships": []}
            )
            entity_map[e.name]["entities"].append(e)

        for r in relationships:
            # subject
            entity_map.setdefault(
                r.subject, {"entities": [], "relationships": []}
            )
            entity_map[r.subject]["relationships"].append(r)

        # sort by # of relationships
        sorted_entries = sorted(
            entity_map.items(),
            key=lambda x: len(x[1]["relationships"]),
            reverse=True,
        )

        # build up the prompt text
        prompt_chunks = []
        cur_len = 0
        for entity_name, data in sorted_entries:
            block = f"\nEntity: {entity_name}\nDescriptions:\n"
            block += "\n".join(
                f"{e.id},{(e.description or '')}" for e in data["entities"]
            )
            block += "\nRelationships:\n"
            block += "\n".join(
                f"{r.id},{r.subject},{r.object},{r.predicate},{r.description or ''}"
                for r in data["relationships"]
            )
            # check length
            if cur_len + len(block) > max_summary_input_length:
                prompt_chunks.append(
                    block[: max_summary_input_length - cur_len]
                )
                break
            else:
                prompt_chunks.append(block)
                cur_len += len(block)

        return "".join(prompt_chunks)

    async def delete(
        self,
        collection_id: UUID,
        **kwargs,
    ):
        return await self.providers.database.graphs_handler.delete(
            collection_id=collection_id,
        )

    async def graph_search_results_extraction(
        self,
        document_id: UUID,
        generation_config: GenerationConfig,
        entity_types: list[str],
        relation_types: list[str],
        chunk_merge_count: int,
        filter_out_existing_chunks: bool = True,
        total_tasks: Optional[int] = None,
        *args: Any,
        **kwargs: Any,
    ) -> AsyncGenerator[GraphExtraction | R2RDocumentProcessingError, None]:
        """The original "extract Graph from doc" logic, but inlined instead of
        referencing a pipe."""
        start_time = time.time()

        logger.info(
            f"Graph Extraction: Processing document {document_id} for graph extraction"
        )

        # Retrieve chunks from DB
        chunks = []
        limit = 100
        offset = 0
        while True:
            chunk_req = await self.providers.database.chunks_handler.list_document_chunks(
                document_id=document_id,
                offset=offset,
                limit=limit,
            )
            new_chunk_objs = [
                DocumentChunk(
                    id=chunk["id"],
                    document_id=chunk["document_id"],
                    owner_id=chunk["owner_id"],
                    collection_ids=chunk["collection_ids"],
                    data=chunk["text"],
                    metadata=chunk["metadata"],
                )
                for chunk in chunk_req["results"]
            ]
            chunks.extend(new_chunk_objs)
            if len(chunk_req["results"]) < limit:
                break
            offset += limit

        if not chunks:
            logger.info(f"No chunks found for document {document_id}")
            raise R2RException(
                message="No chunks found for document",
                status_code=404,
            )

        # Possibly filter out any chunks that have already been processed
        if filter_out_existing_chunks:
            existing_chunk_ids = await self.providers.database.graphs_handler.get_existing_document_entity_chunk_ids(
                document_id=document_id
            )
            before_count = len(chunks)
            chunks = [c for c in chunks if c.id not in existing_chunk_ids]
            logger.info(
                f"Filtered out {len(existing_chunk_ids)} existing chunk-IDs. {before_count}->{len(chunks)} remain."
            )
            if not chunks:
                return  # nothing left to yield

        # sort by chunk_order if present
        chunks = sorted(
            chunks,
            key=lambda x: x.metadata.get("chunk_order", float("inf")),
        )

        # group them
        grouped_chunks = [
            chunks[i : i + chunk_merge_count]
            for i in range(0, len(chunks), chunk_merge_count)
        ]

        logger.info(
            f"Graph Extraction: Created {len(grouped_chunks)} tasks for doc={document_id}"
        )

        tasks = [
            asyncio.create_task(
                self._extract_graph_search_results_from_chunk_group(
                    chunk_group,
                    generation_config,
                    entity_types,
                    relation_types,
                )
            )
            for chunk_group in grouped_chunks
        ]

        completed_tasks = 0
        for t in asyncio.as_completed(tasks):
            try:
                yield await t
                completed_tasks += 1
                if completed_tasks % 100 == 0:
                    logger.info(
                        f"Graph Extraction: completed {completed_tasks}/{len(tasks)} tasks"
                    )
            except Exception as e:
                logger.error(f"Error extracting from chunk group: {e}")
                yield R2RDocumentProcessingError(
                    document_id=document_id,
                    error_message=str(e),
                )

        logger.info(
            f"Graph Extraction: done with {document_id}, time={time.time() - start_time:.2f}s"
        )
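
    # The method above is an async generator, so a caller typically drains it
    # and hands the successful extractions to
    # store_graph_search_results_extractions (defined further below). A rough,
    # illustrative consumption pattern (the surrounding names and the example
    # entity/relation type lists are assumptions, not part of this module):
    #
    #   extractions = []
    #   async for item in service.graph_search_results_extraction(
    #       document_id=document_id,
    #       generation_config=gen_cfg,
    #       entity_types=["PERSON", "ORGANIZATION"],
    #       relation_types=["WORKS_AT"],
    #       chunk_merge_count=4,
    #   ):
    #       if isinstance(item, R2RDocumentProcessingError):
    #           logger.error(item.error_message)
    #           continue
    #       extractions.append(item)
    #   await service.store_graph_search_results_extractions(extractions)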

    async def _extract_graph_search_results_from_chunk_group(
        self,
        chunks: list[DocumentChunk],
        generation_config: GenerationConfig,
        entity_types: list[str],
        relation_types: list[str],
        retries: int = 5,
        delay: int = 2,
    ) -> GraphExtraction:
        """(Equivalent to _extract_graph_search_results in old code.)

        Merges chunk data, calls LLM, parses XML, returns GraphExtraction object.
        """
        combined_extraction: str = " ".join(
            [
                c.data.decode("utf-8") if isinstance(c.data, bytes) else c.data
                for c in chunks
                if c.data
            ]
        )

        # Possibly get doc-level summary
        doc_id = chunks[0].document_id
        response = await self.providers.database.documents_handler.get_documents_overview(
            offset=0,
            limit=1,
            filter_document_ids=[doc_id],
        )
        document_summary = (
            response["results"][0].summary if response["results"] else None
        )

        # Build messages/prompt
        prompt_name = self.providers.database.config.graph_creation_settings.graph_extraction_prompt
        messages = (
            await self.providers.database.prompts_handler.get_message_payload(
                task_prompt_name=prompt_name,
                task_inputs={
                    "document_summary": document_summary or "",
                    "input": combined_extraction,
                    "entity_types": "\n".join(entity_types),
                    "relation_types": "\n".join(relation_types),
                },
            )
        )

        for attempt in range(retries):
            try:
                resp = await self.providers.llm.aget_completion(
                    messages, generation_config=generation_config
                )
                graph_search_results_str = resp.choices[0].message.content

                if not graph_search_results_str:
                    raise R2RException(
                        "No extraction found in LLM response.",
                        400,
                    )

                logger.info(generation_config)
                logger.info(graph_search_results_str)

                # parse the XML
                (
                    entities,
                    relationships,
                ) = await self._parse_graph_search_results_extraction_xml(
                    graph_search_results_str, chunks
                )
                return GraphExtraction(
                    entities=entities, relationships=relationships
                )
            except Exception as e:
                if attempt < retries - 1:
                    await asyncio.sleep(delay)
                    continue
                else:
                    logger.error(
                        f"All extraction attempts for doc={doc_id} and chunks{[chunk.id for chunk in chunks]} failed with error:\n{e}"
                    )
                    return GraphExtraction(entities=[], relationships=[])

        return GraphExtraction(entities=[], relationships=[])

    async def _parse_graph_search_results_extraction_xml(
        self, response_str: str, chunks: list[DocumentChunk]
    ) -> tuple[list[Entity], list[Relationship]]:
        """Helper to parse the LLM's XML format, handle edge cases/cleanup,
        produce Entities/Relationships."""

        def sanitize_xml(r: str) -> str:
            # Remove markdown fences
            r = re.sub(r"```xml|```", "", r)
            # Remove xml instructions or userStyle
            r = re.sub(r"<\?.*?\?>", "", r)
            r = re.sub(r"<userStyle>.*?</userStyle>", "", r)
            # Escape bare `&` as `&amp;`
            r = re.sub(r"&(?!amp;|quot;|apos;|lt;|gt;)", "&amp;", r)
            # Also remove <root> if it appears
            r = r.replace("<root>", "").replace("</root>", "")
            return r.strip()

        cleaned_xml = sanitize_xml(response_str)
        wrapped = f"<root>{cleaned_xml}</root>"
        try:
            root = ET.fromstring(
                wrapped, parser=ET.XMLParser(encoding="utf-8")
            )
        except ET.ParseError:
            raise R2RException(
                f"Failed to parse XML:\nData: {wrapped}", 400
            ) from None

        entities_elems = root.findall(".//entity")
        if (
            len(response_str) > MIN_VALID_GRAPH_EXTRACTION_RESPONSE_LENGTH
            and len(entities_elems) == 0
        ):
            raise R2RException(
                f"No <entity> found in LLM XML, possibly malformed. Response excerpt: {response_str[:300]}",
                400,
            )

        # build entity objects
        doc_id = chunks[0].document_id
        chunk_ids = [c.id for c in chunks]
        entities_list: list[Entity] = []
        for element in entities_elems:
            name_attr = element.get("name")
            type_elem = element.find("type")
            desc_elem = element.find("description")
            category = type_elem.text if type_elem is not None else None
            desc = desc_elem.text if desc_elem is not None else None
            desc_embed = await self.providers.embedding.async_get_embedding(
                desc or ""
            )
            ent = Entity(
                category=category,
                description=desc,
                name=name_attr,
                parent_id=doc_id,
                chunk_ids=chunk_ids,
                description_embedding=desc_embed,
                attributes={},
            )
            entities_list.append(ent)

        # build relationship objects
        relationships_list: list[Relationship] = []
        rel_elems = root.findall(".//relationship")
        for r_elem in rel_elems:
            source_elem = r_elem.find("source")
            target_elem = r_elem.find("target")
            type_elem = r_elem.find("type")
            desc_elem = r_elem.find("description")
            weight_elem = r_elem.find("weight")
            try:
                subject = source_elem.text if source_elem is not None else ""
                object_ = target_elem.text if target_elem is not None else ""
                predicate = type_elem.text if type_elem is not None else ""
                desc = desc_elem.text if desc_elem is not None else ""
                weight = (
                    float(weight_elem.text)
                    if isinstance(weight_elem, Element) and weight_elem.text
                    else ""
                )
                embed = await self.providers.embedding.async_get_embedding(
                    desc or ""
                )
                rel = Relationship(
                    subject=subject,
                    predicate=predicate,
                    object=object_,
                    description=desc,
                    weight=weight,
                    parent_id=doc_id,
                    chunk_ids=chunk_ids,
                    attributes={},
                    description_embedding=embed,
                )
                relationships_list.append(rel)
            except Exception:
                continue

        return entities_list, relationships_list

    async def store_graph_search_results_extractions(
        self,
        graph_search_results_extractions: list[GraphExtraction],
    ):
        """Stores a batch of knowledge graph extractions in the DB."""
        for extraction in graph_search_results_extractions:
            # Map name->id after creation
            entities_id_map = {}
            for e in extraction.entities:
                if e.parent_id is not None:
                    result = await self.providers.database.graphs_handler.entities.create(
                        name=e.name,
                        parent_id=e.parent_id,
                        store_type=StoreType.DOCUMENTS,
                        category=e.category,
                        description=e.description,
                        description_embedding=e.description_embedding,
                        chunk_ids=e.chunk_ids,
                        metadata=e.metadata,
                    )
                    entities_id_map[e.name] = result.id
                else:
                    logger.warning(f"Skipping entity with None parent_id: {e}")

            # Insert relationships
            for rel in extraction.relationships:
                subject_id = entities_id_map.get(rel.subject)
                object_id = entities_id_map.get(rel.object)
                parent_id = rel.parent_id

                if any(
                    id is None for id in (subject_id, object_id, parent_id)
                ):
                    logger.warning(f"Missing ID for relationship: {rel}")
                    continue

                assert isinstance(subject_id, UUID)
                assert isinstance(object_id, UUID)
                assert isinstance(parent_id, UUID)

                await self.providers.database.graphs_handler.relationships.create(
                    subject=rel.subject,
                    subject_id=subject_id,
                    predicate=rel.predicate,
                    object=rel.object,
                    object_id=object_id,
                    parent_id=parent_id,
                    description=rel.description,
                    description_embedding=rel.description_embedding,
                    weight=rel.weight,
                    metadata=rel.metadata,
                    store_type=StoreType.DOCUMENTS,
                )
""" merged_results = await self.providers.database.entities_handler.merge_duplicate_name_blocks( parent_id=document_id, store_type=StoreType.DOCUMENTS, ) # Grab doc summary response = await self.providers.database.documents_handler.get_documents_overview( offset=0, limit=1, filter_document_ids=[document_id], ) document_summary = ( response["results"][0].summary if response["results"] else None ) # For each merged entity for original_entities, merged_entity in merged_results: # Summarize them with LLM entity_info = "\n".join( e.description for e in original_entities if e.description ) messages = await self.providers.database.prompts_handler.get_message_payload( task_prompt_name=self.providers.database.config.graph_creation_settings.graph_entity_description_prompt, task_inputs={ "document_summary": document_summary, "entity_info": f"{merged_entity.name}\n{entity_info}", "relationships_txt": "", }, ) gen_config = ( self.config.database.graph_creation_settings.generation_config or GenerationConfig(model=self.config.app.fast_llm) ) resp = await self.providers.llm.aget_completion( messages, generation_config=gen_config ) new_description = resp.choices[0].message.content new_embedding = await self.providers.embedding.async_get_embedding( new_description or "" ) if merged_entity.id is not None: await self.providers.database.graphs_handler.entities.update( entity_id=merged_entity.id, store_type=StoreType.DOCUMENTS, description=new_description, description_embedding=str(new_embedding), ) else: logger.warning("Skipping update for entity with None id")