import asyncio import json import logging import math import re import time from typing import Any, AsyncGenerator, Optional from uuid import UUID from core.base import ( DocumentChunk, KGExtraction, KGExtractionStatus, R2RDocumentProcessingError, RunManager, ) from core.base.abstractions import ( Community, Entity, GenerationConfig, KGCreationSettings, KGEnrichmentSettings, KGEnrichmentStatus, KGEntityDeduplicationSettings, KGEntityDeduplicationType, R2RException, Relationship, ) from core.base.api.models import GraphResponse from core.telemetry.telemetry_decorator import telemetry_event from ..abstractions import R2RAgents, R2RPipelines, R2RPipes, R2RProviders from ..config import R2RConfig from .base import Service logger = logging.getLogger() MIN_VALID_KG_EXTRACTION_RESPONSE_LENGTH = 128 async def _collect_results(result_gen: AsyncGenerator) -> list[dict]: results = [] async for res in result_gen: results.append(res.json() if hasattr(res, "json") else res) return results # TODO - Fix naming convention to read `KGService` instead of `GraphService` # this will require a minor change in how services are registered. class GraphService(Service): def __init__( self, config: R2RConfig, providers: R2RProviders, pipes: R2RPipes, pipelines: R2RPipelines, agents: R2RAgents, run_manager: RunManager, ): super().__init__( config, providers, pipes, pipelines, agents, run_manager, ) @telemetry_event("kg_relationships_extraction") async def kg_relationships_extraction( self, document_id: UUID, generation_config: GenerationConfig, chunk_merge_count: int, max_knowledge_relationships: int, entity_types: list[str], relation_types: list[str], **kwargs, ): try: logger.info( f"KGService: Processing document {document_id} for KG extraction" ) await self.providers.database.documents_handler.set_workflow_status( id=document_id, status_type="extraction_status", status=KGExtractionStatus.PROCESSING, ) relationships = await self.pipes.graph_extraction_pipe.run( input=self.pipes.graph_extraction_pipe.Input( message={ "document_id": document_id, "generation_config": generation_config, "chunk_merge_count": chunk_merge_count, "max_knowledge_relationships": max_knowledge_relationships, "entity_types": entity_types, "relation_types": relation_types, "logger": logger, } ), state=None, run_manager=self.run_manager, ) logger.info( f"KGService: Finished processing document {document_id} for KG extraction" ) result_gen = await self.pipes.graph_storage_pipe.run( input=self.pipes.graph_storage_pipe.Input( message=relationships ), state=None, run_manager=self.run_manager, ) except Exception as e: logger.error(f"KGService: Error in kg_extraction: {e}") await self.providers.database.documents_handler.set_workflow_status( id=document_id, status_type="extraction_status", status=KGExtractionStatus.FAILED, ) raise e return await _collect_results(result_gen) @telemetry_event("create_entity") async def create_entity( self, name: str, description: str, parent_id: UUID, category: Optional[str] = None, metadata: Optional[dict] = None, ) -> Entity: description_embedding = str( await self.providers.embedding.async_get_embedding(description) ) return await self.providers.database.graphs_handler.entities.create( name=name, parent_id=parent_id, store_type="graphs", # type: ignore category=category, description=description, description_embedding=description_embedding, metadata=metadata, ) @telemetry_event("update_entity") async def update_entity( self, entity_id: UUID, name: Optional[str] = None, description: Optional[str] = None, category: Optional[str] = None, metadata: Optional[dict] = None, ) -> Entity: description_embedding = None if description is not None: description_embedding = str( await self.providers.embedding.async_get_embedding(description) ) return await self.providers.database.graphs_handler.entities.update( entity_id=entity_id, store_type="graphs", # type: ignore name=name, description=description, description_embedding=description_embedding, category=category, metadata=metadata, ) @telemetry_event("delete_entity") async def delete_entity( self, parent_id: UUID, entity_id: UUID, ): return await self.providers.database.graphs_handler.entities.delete( parent_id=parent_id, entity_ids=[entity_id], store_type="graphs", # type: ignore ) @telemetry_event("get_entities") async def get_entities( self, parent_id: UUID, offset: int, limit: int, entity_ids: Optional[list[UUID]] = None, entity_names: Optional[list[str]] = None, include_embeddings: bool = False, ): return await self.providers.database.graphs_handler.get_entities( parent_id=parent_id, offset=offset, limit=limit, entity_ids=entity_ids, entity_names=entity_names, include_embeddings=include_embeddings, ) @telemetry_event("create_relationship") async def create_relationship( self, subject: str, subject_id: UUID, predicate: str, object: str, object_id: UUID, parent_id: UUID, description: str | None = None, weight: float | None = 1.0, metadata: Optional[dict[str, Any] | str] = None, ) -> Relationship: description_embedding = None if description: description_embedding = str( await self.providers.embedding.async_get_embedding(description) ) return ( await self.providers.database.graphs_handler.relationships.create( subject=subject, subject_id=subject_id, predicate=predicate, object=object, object_id=object_id, parent_id=parent_id, description=description, description_embedding=description_embedding, weight=weight, metadata=metadata, store_type="graphs", # type: ignore ) ) @telemetry_event("delete_relationship") async def delete_relationship( self, parent_id: UUID, relationship_id: UUID, ): return ( await self.providers.database.graphs_handler.relationships.delete( parent_id=parent_id, relationship_ids=[relationship_id], store_type="graphs", # type: ignore ) ) @telemetry_event("update_relationship") async def update_relationship( self, relationship_id: UUID, subject: Optional[str] = None, subject_id: Optional[UUID] = None, predicate: Optional[str] = None, object: Optional[str] = None, object_id: Optional[UUID] = None, description: Optional[str] = None, weight: Optional[float] = None, metadata: Optional[dict[str, Any] | str] = None, ) -> Relationship: description_embedding = None if description is not None: description_embedding = str( await self.providers.embedding.async_get_embedding(description) ) return ( await self.providers.database.graphs_handler.relationships.update( relationship_id=relationship_id, subject=subject, subject_id=subject_id, predicate=predicate, object=object, object_id=object_id, description=description, description_embedding=description_embedding, weight=weight, metadata=metadata, store_type="graphs", # type: ignore ) ) @telemetry_event("get_relationships") async def get_relationships( self, parent_id: UUID, offset: int, limit: int, relationship_ids: Optional[list[UUID]] = None, entity_names: Optional[list[str]] = None, ): return await self.providers.database.graphs_handler.relationships.get( parent_id=parent_id, store_type="graphs", # type: ignore offset=offset, limit=limit, relationship_ids=relationship_ids, entity_names=entity_names, ) @telemetry_event("create_community") async def create_community( self, parent_id: UUID, name: str, summary: str, findings: Optional[list[str]], rating: Optional[float], rating_explanation: Optional[str], ) -> Community: description_embedding = str( await self.providers.embedding.async_get_embedding(summary) ) return await self.providers.database.graphs_handler.communities.create( parent_id=parent_id, store_type="graphs", # type: ignore name=name, summary=summary, description_embedding=description_embedding, findings=findings, rating=rating, rating_explanation=rating_explanation, ) @telemetry_event("update_community") async def update_community( self, community_id: UUID, name: Optional[str], summary: Optional[str], findings: Optional[list[str]], rating: Optional[float], rating_explanation: Optional[str], ) -> Community: summary_embedding = None if summary is not None: summary_embedding = str( await self.providers.embedding.async_get_embedding(summary) ) return await self.providers.database.graphs_handler.communities.update( community_id=community_id, store_type="graphs", # type: ignore name=name, summary=summary, summary_embedding=summary_embedding, findings=findings, rating=rating, rating_explanation=rating_explanation, ) @telemetry_event("delete_community") async def delete_community( self, parent_id: UUID, community_id: UUID, ) -> None: await self.providers.database.graphs_handler.communities.delete( parent_id=parent_id, community_id=community_id, ) @telemetry_event("list_communities") async def list_communities( self, collection_id: UUID, offset: int, limit: int, ): return await self.providers.database.graphs_handler.communities.get( parent_id=collection_id, store_type="graphs", # type: ignore offset=offset, limit=limit, ) @telemetry_event("get_communities") async def get_communities( self, parent_id: UUID, offset: int, limit: int, community_ids: Optional[list[UUID]] = None, community_names: Optional[list[str]] = None, include_embeddings: bool = False, ): return await self.providers.database.graphs_handler.get_communities( parent_id=parent_id, offset=offset, limit=limit, community_ids=community_ids, include_embeddings=include_embeddings, ) # @telemetry_event("create_new_graph") # async def create_new_graph( # self, # collection_id: UUID, # user_id: UUID, # name: Optional[str], # description: str = "", # ) -> GraphResponse: # return await self.providers.database.graphs_handler.create( # collection_id=collection_id, # user_id=user_id, # name=name, # description=description, # graph_id=collection_id, # ) async def list_graphs( self, offset: int, limit: int, # user_ids: Optional[list[UUID]] = None, graph_ids: Optional[list[UUID]] = None, collection_id: Optional[UUID] = None, ) -> dict[str, list[GraphResponse] | int]: return await self.providers.database.graphs_handler.list_graphs( offset=offset, limit=limit, # filter_user_ids=user_ids, filter_graph_ids=graph_ids, filter_collection_id=collection_id, ) @telemetry_event("update_graph") async def update_graph( self, collection_id: UUID, name: Optional[str] = None, description: Optional[str] = None, ) -> GraphResponse: return await self.providers.database.graphs_handler.update( collection_id=collection_id, name=name, description=description, ) @telemetry_event("reset_graph_v3") async def reset_graph_v3(self, id: UUID) -> bool: await self.providers.database.graphs_handler.reset( parent_id=id, ) await self.providers.database.documents_handler.set_workflow_status( id=id, status_type="graph_cluster_status", status=KGEnrichmentStatus.PENDING, ) return True @telemetry_event("get_document_ids_for_create_graph") async def get_document_ids_for_create_graph( self, collection_id: UUID, force_kg_creation: bool = False, **kwargs, ): document_status_filter = [ KGExtractionStatus.PENDING, KGExtractionStatus.FAILED, ] if force_kg_creation: document_status_filter += [ KGExtractionStatus.PROCESSING, ] return await self.providers.database.documents_handler.get_document_ids_by_status( status_type="extraction_status", status=[str(ele) for ele in document_status_filter], collection_id=collection_id, ) @telemetry_event("kg_entity_description") async def kg_entity_description( self, document_id: UUID, max_description_input_length: int, **kwargs, ): start_time = time.time() logger.info( f"KGService: Running kg_entity_description for document {document_id}" ) entity_count = ( await self.providers.database.graphs_handler.get_entity_count( document_id=document_id, distinct=True, entity_table_name="documents_entities", ) ) logger.info( f"KGService: Found {entity_count} entities in document {document_id}" ) # TODO - Do not hardcode the batch size, # make it a configurable parameter at runtime & server-side defaults # process 256 entities at a time num_batches = math.ceil(entity_count / 256) logger.info( f"Calling `kg_entity_description` on document {document_id} with an entity count of {entity_count} and total batches of {num_batches}" ) all_results = [] for i in range(num_batches): logger.info( f"KGService: Running kg_entity_description for batch {i+1}/{num_batches} for document {document_id}" ) node_descriptions = await self.pipes.graph_description_pipe.run( input=self.pipes.graph_description_pipe.Input( message={ "offset": i * 256, "limit": 256, "max_description_input_length": max_description_input_length, "document_id": document_id, "logger": logger, } ), state=None, run_manager=self.run_manager, ) all_results.append(await _collect_results(node_descriptions)) logger.info( f"KGService: Completed kg_entity_description for batch {i+1}/{num_batches} for document {document_id}" ) await self.providers.database.documents_handler.set_workflow_status( id=document_id, status_type="extraction_status", status=KGExtractionStatus.SUCCESS, ) logger.info( f"KGService: Completed kg_entity_description for document {document_id} in {time.time() - start_time:.2f} seconds", ) return all_results @telemetry_event("kg_clustering") async def kg_clustering( self, collection_id: UUID, # graph_id: UUID, generation_config: GenerationConfig, leiden_params: dict, **kwargs, ): logger.info( f"Running ClusteringPipe for collection {collection_id} with settings {leiden_params}" ) clustering_result = await self.pipes.graph_clustering_pipe.run( input=self.pipes.graph_clustering_pipe.Input( message={ "collection_id": collection_id, "generation_config": generation_config, "leiden_params": leiden_params, "logger": logger, "clustering_mode": self.config.database.graph_creation_settings.clustering_mode, } ), state=None, run_manager=self.run_manager, ) return await _collect_results(clustering_result) @telemetry_event("kg_community_summary") async def kg_community_summary( self, offset: int, limit: int, max_summary_input_length: int, generation_config: GenerationConfig, collection_id: UUID | None, # graph_id: UUID | None, **kwargs, ): summary_results = await self.pipes.graph_community_summary_pipe.run( input=self.pipes.graph_community_summary_pipe.Input( message={ "offset": offset, "limit": limit, "generation_config": generation_config, "max_summary_input_length": max_summary_input_length, "collection_id": collection_id, # "graph_id": graph_id, "logger": logger, } ), state=None, run_manager=self.run_manager, ) return await _collect_results(summary_results) @telemetry_event("delete_graph_for_documents") async def delete_graph_for_documents( self, document_ids: list[UUID], **kwargs, ): # TODO: Implement this, as it needs some checks. raise NotImplementedError @telemetry_event("delete_graph") async def delete_graph( self, collection_id: UUID, cascade: bool, **kwargs, ): return await self.delete(collection_id=collection_id, cascade=cascade) @telemetry_event("delete") async def delete( self, collection_id: UUID, cascade: bool, **kwargs, ): return await self.providers.database.graphs_handler.delete( collection_id=collection_id, cascade=cascade, ) @telemetry_event("get_creation_estimate") async def get_creation_estimate( self, graph_creation_settings: KGCreationSettings, document_id: Optional[UUID] = None, collection_id: Optional[UUID] = None, **kwargs, ): return ( await self.providers.database.graphs_handler.get_creation_estimate( document_id=document_id, collection_id=collection_id, graph_creation_settings=graph_creation_settings, ) ) @telemetry_event("get_enrichment_estimate") async def get_enrichment_estimate( self, collection_id: Optional[UUID] = None, graph_id: Optional[UUID] = None, graph_enrichment_settings: KGEnrichmentSettings = KGEnrichmentSettings(), **kwargs, ): if graph_id is None and collection_id is None: raise ValueError( "Either graph_id or collection_id must be provided" ) return await self.providers.database.graphs_handler.get_enrichment_estimate( collection_id=collection_id, graph_id=graph_id, graph_enrichment_settings=graph_enrichment_settings, ) @telemetry_event("get_deduplication_estimate") async def get_deduplication_estimate( self, collection_id: UUID, kg_deduplication_settings: KGEntityDeduplicationSettings, **kwargs, ): return await self.providers.database.graphs_handler.get_deduplication_estimate( collection_id=collection_id, kg_deduplication_settings=kg_deduplication_settings, ) @telemetry_event("kg_entity_deduplication") async def kg_entity_deduplication( self, collection_id: UUID, graph_id: UUID, graph_entity_deduplication_type: KGEntityDeduplicationType, graph_entity_deduplication_prompt: str, generation_config: GenerationConfig, **kwargs, ): deduplication_results = await self.pipes.graph_deduplication_pipe.run( input=self.pipes.graph_deduplication_pipe.Input( message={ "collection_id": collection_id, "graph_id": graph_id, "graph_entity_deduplication_type": graph_entity_deduplication_type, "graph_entity_deduplication_prompt": graph_entity_deduplication_prompt, "generation_config": generation_config, **kwargs, } ), state=None, run_manager=self.run_manager, ) return await _collect_results(deduplication_results) @telemetry_event("kg_entity_deduplication_summary") async def kg_entity_deduplication_summary( self, collection_id: UUID, offset: int, limit: int, graph_entity_deduplication_type: KGEntityDeduplicationType, graph_entity_deduplication_prompt: str, generation_config: GenerationConfig, **kwargs, ): logger.info( f"Running kg_entity_deduplication_summary for collection {collection_id} with settings {kwargs}" ) deduplication_summary_results = await self.pipes.graph_deduplication_summary_pipe.run( input=self.pipes.graph_deduplication_summary_pipe.Input( message={ "collection_id": collection_id, "offset": offset, "limit": limit, "graph_entity_deduplication_type": graph_entity_deduplication_type, "graph_entity_deduplication_prompt": graph_entity_deduplication_prompt, "generation_config": generation_config, } ), state=None, run_manager=self.run_manager, ) return await _collect_results(deduplication_summary_results) async def kg_extraction( # type: ignore self, document_id: UUID, generation_config: GenerationConfig, max_knowledge_relationships: int, entity_types: list[str], relation_types: list[str], chunk_merge_count: int, filter_out_existing_chunks: bool = True, total_tasks: Optional[int] = None, *args: Any, **kwargs: Any, ) -> AsyncGenerator[KGExtraction | R2RDocumentProcessingError, None]: start_time = time.time() logger.info( f"GraphExtractionPipe: Processing document {document_id} for KG extraction", ) # Then create the extractions from the results limit = 100 offset = 0 chunks = [] while True: chunk_req = await self.providers.database.chunks_handler.list_document_chunks( # FIXME: This was using the pagination defaults from before... We need to review if this is as intended. document_id=document_id, offset=offset, limit=limit, ) chunks.extend( [ DocumentChunk( id=chunk["id"], document_id=chunk["document_id"], owner_id=chunk["owner_id"], collection_ids=chunk["collection_ids"], data=chunk["text"], metadata=chunk["metadata"], ) for chunk in chunk_req["results"] ] ) if len(chunk_req["results"]) < limit: break offset += limit logger.info(f"Found {len(chunks)} chunks for document {document_id}") if len(chunks) == 0: logger.info(f"No chunks found for document {document_id}") raise R2RException( message="No chunks found for document", status_code=404, ) if filter_out_existing_chunks: existing_chunk_ids = await self.providers.database.graphs_handler.get_existing_document_entity_chunk_ids( document_id=document_id ) chunks = [ chunk for chunk in chunks if chunk.id not in existing_chunk_ids ] logger.info( f"Filtered out {len(existing_chunk_ids)} existing chunks, remaining {len(chunks)} chunks for document {document_id}" ) if len(chunks) == 0: logger.info(f"No extractions left for document {document_id}") return logger.info( f"GraphExtractionPipe: Obtained {len(chunks)} chunks to process, time from start: {time.time() - start_time:.2f} seconds", ) # sort the extractions accroding to chunk_order field in metadata in ascending order chunks = sorted( chunks, key=lambda x: x.metadata.get("chunk_order", float("inf")), ) # group these extractions into groups of chunk_merge_count grouped_chunks = [ chunks[i : i + chunk_merge_count] for i in range(0, len(chunks), chunk_merge_count) ] logger.info( f"GraphExtractionPipe: Extracting KG Relationships for document and created {len(grouped_chunks)} tasks, time from start: {time.time() - start_time:.2f} seconds", ) tasks = [ asyncio.create_task( self._extract_kg( chunks=chunk_group, generation_config=generation_config, max_knowledge_relationships=max_knowledge_relationships, entity_types=entity_types, relation_types=relation_types, task_id=task_id, total_tasks=len(grouped_chunks), ) ) for task_id, chunk_group in enumerate(grouped_chunks) ] completed_tasks = 0 total_tasks = len(tasks) logger.info( f"GraphExtractionPipe: Waiting for {total_tasks} KG extraction tasks to complete", ) for completed_task in asyncio.as_completed(tasks): try: yield await completed_task completed_tasks += 1 if completed_tasks % 100 == 0: logger.info( f"GraphExtractionPipe: Completed {completed_tasks}/{total_tasks} KG extraction tasks", ) except Exception as e: logger.error(f"Error in Extracting KG Relationships: {e}") yield R2RDocumentProcessingError( document_id=document_id, error_message=str(e), ) logger.info( f"GraphExtractionPipe: Completed {completed_tasks}/{total_tasks} KG extraction tasks, time from start: {time.time() - start_time:.2f} seconds", ) async def _extract_kg( self, chunks: list[DocumentChunk], generation_config: GenerationConfig, max_knowledge_relationships: int, entity_types: list[str], relation_types: list[str], retries: int = 5, delay: int = 2, task_id: Optional[int] = None, total_tasks: Optional[int] = None, ) -> KGExtraction: """ Extracts NER relationships from a extraction with retries. """ # combine all extractions into a single string combined_extraction: str = " ".join([chunk.data for chunk in chunks]) # type: ignore response = await self.providers.database.documents_handler.get_documents_overview( # type: ignore offset=0, limit=1, filter_document_ids=[chunks[0].document_id], ) document_summary = ( response["results"][0].summary if response["results"] else None ) messages = await self.providers.database.prompts_handler.get_message_payload( task_prompt_name=self.providers.database.config.graph_creation_settings.graphrag_relationships_extraction_few_shot, task_inputs={ "document_summary": document_summary, "input": combined_extraction, "max_knowledge_relationships": max_knowledge_relationships, "entity_types": "\n".join(entity_types), "relation_types": "\n".join(relation_types), }, ) for attempt in range(retries): try: response = await self.providers.llm.aget_completion( messages, generation_config=generation_config, ) kg_extraction = response.choices[0].message.content if not kg_extraction: raise R2RException( "No knowledge graph extraction found in the response string, the selected LLM likely failed to format it's response correctly.", 400, ) entity_pattern = ( r'\("entity"\${4}([^$]+)\${4}([^$]+)\${4}([^$]+)\)' ) relationship_pattern = r'\("relationship"\${4}([^$]+)\${4}([^$]+)\${4}([^$]+)\${4}([^$]+)\${4}(\d+(?:\.\d+)?)\)' async def parse_fn(response_str: str) -> Any: entities = re.findall(entity_pattern, response_str) if ( len(kg_extraction) > MIN_VALID_KG_EXTRACTION_RESPONSE_LENGTH and len(entities) == 0 ): raise R2RException( f"No entities found in the response string, the selected LLM likely failed to format it's response correctly. {response_str}", 400, ) relationships = re.findall( relationship_pattern, response_str ) entities_arr = [] for entity in entities: entity_value = entity[0] entity_category = entity[1] entity_description = entity[2] description_embedding = ( await self.providers.embedding.async_get_embedding( entity_description ) ) entities_arr.append( Entity( category=entity_category, description=entity_description, name=entity_value, parent_id=chunks[0].document_id, chunk_ids=[chunk.id for chunk in chunks], description_embedding=description_embedding, attributes={}, ) ) relations_arr = [] for relationship in relationships: subject = relationship[0] object = relationship[1] predicate = relationship[2] description = relationship[3] weight = float(relationship[4]) relationship_embedding = ( await self.providers.embedding.async_get_embedding( description ) ) # check if subject and object are in entities_dict relations_arr.append( Relationship( subject=subject, predicate=predicate, object=object, description=description, weight=weight, parent_id=chunks[0].document_id, chunk_ids=[chunk.id for chunk in chunks], attributes={}, description_embedding=relationship_embedding, ) ) return entities_arr, relations_arr entities, relationships = await parse_fn(kg_extraction) return KGExtraction( entities=entities, relationships=relationships, ) except ( Exception, json.JSONDecodeError, KeyError, IndexError, R2RException, ) as e: if attempt < retries - 1: await asyncio.sleep(delay) else: logger.warning( f"Failed after retries with for chunk {chunks[0].id} of document {chunks[0].document_id}: {e}" ) logger.info( f"GraphExtractionPipe: Completed task number {task_id} of {total_tasks} for document {chunks[0].document_id}", ) return KGExtraction( entities=[], relationships=[], ) async def store_kg_extractions( self, kg_extractions: list[KGExtraction], ): """ Stores a batch of knowledge graph extractions in the graph database. """ for extraction in kg_extractions: entities_id_map = {} for entity in extraction.entities: result = await self.providers.database.graphs_handler.entities.create( name=entity.name, parent_id=entity.parent_id, store_type="documents", # type: ignore category=entity.category, description=entity.description, description_embedding=entity.description_embedding, chunk_ids=entity.chunk_ids, metadata=entity.metadata, ) entities_id_map[entity.name] = result.id if extraction.relationships: for relationship in extraction.relationships: await self.providers.database.graphs_handler.relationships.create( subject=relationship.subject, subject_id=entities_id_map.get(relationship.subject), predicate=relationship.predicate, object=relationship.object, object_id=entities_id_map.get(relationship.object), parent_id=relationship.parent_id, description=relationship.description, description_embedding=relationship.description_embedding, weight=relationship.weight, metadata=relationship.metadata, store_type="documents", # type: ignore )