import asyncio
import json
import logging
import math
import re
import time
from typing import Any, AsyncGenerator, Optional
from uuid import UUID

from core.base import (
    DocumentChunk,
    KGExtraction,
    KGExtractionStatus,
    R2RDocumentProcessingError,
    RunManager,
)
from core.base.abstractions import (
    Community,
    Entity,
    GenerationConfig,
    KGCreationSettings,
    KGEnrichmentSettings,
    KGEnrichmentStatus,
    R2RException,
    Relationship,
    StoreType,
)
from core.base.api.models import GraphResponse
from core.telemetry.telemetry_decorator import telemetry_event

from ..abstractions import R2RAgents, R2RPipelines, R2RPipes, R2RProviders
from ..config import R2RConfig
from .base import Service

logger = logging.getLogger()

MIN_VALID_KG_EXTRACTION_RESPONSE_LENGTH = 128


async def _collect_results(result_gen: AsyncGenerator) -> list[dict]:
    """Drain an async generator into a list, JSON-serializing items that
    support it."""
    results = []
    async for res in result_gen:
        results.append(res.json() if hasattr(res, "json") else res)
    return results
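

# A minimal sketch of how this helper is used throughout this module (the pipe
# name below is illustrative, not a real attribute):
#
#     result_gen = await self.pipes.some_pipe.run(
#         input=..., state=None, run_manager=self.run_manager
#     )
#     results = await _collect_results(result_gen)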


# TODO - Fix naming convention to read `KGService` instead of `GraphService`;
# this will require a minor change in how services are registered.
class GraphService(Service):
    """Service layer for knowledge graph extraction, enrichment and CRUD."""

    def __init__(
        self,
        config: R2RConfig,
        providers: R2RProviders,
        pipes: R2RPipes,
        pipelines: R2RPipelines,
        agents: R2RAgents,
        run_manager: RunManager,
    ):
        super().__init__(
            config,
            providers,
            pipes,
            pipelines,
            agents,
            run_manager,
        )
- @telemetry_event("kg_relationships_extraction")
- async def kg_relationships_extraction(
- self,
- document_id: UUID,
- generation_config: GenerationConfig,
- chunk_merge_count: int,
- max_knowledge_relationships: int,
- entity_types: list[str],
- relation_types: list[str],
- **kwargs,
- ):
- try:
- logger.info(
- f"KGService: Processing document {document_id} for KG extraction"
- )
- await self.providers.database.documents_handler.set_workflow_status(
- id=document_id,
- status_type="extraction_status",
- status=KGExtractionStatus.PROCESSING,
- )
- relationships = await self.pipes.graph_extraction_pipe.run(
- input=self.pipes.graph_extraction_pipe.Input(
- message={
- "document_id": document_id,
- "generation_config": generation_config,
- "chunk_merge_count": chunk_merge_count,
- "max_knowledge_relationships": max_knowledge_relationships,
- "entity_types": entity_types,
- "relation_types": relation_types,
- "logger": logger,
- }
- ),
- state=None,
- run_manager=self.run_manager,
- )
- logger.info(
- f"KGService: Finished processing document {document_id} for KG extraction"
- )
- result_gen = await self.pipes.graph_storage_pipe.run(
- input=self.pipes.graph_storage_pipe.Input(
- message=relationships
- ),
- state=None,
- run_manager=self.run_manager,
- )
- except Exception as e:
- logger.error(f"KGService: Error in kg_extraction: {e}")
- await self.providers.database.documents_handler.set_workflow_status(
- id=document_id,
- status_type="extraction_status",
- status=KGExtractionStatus.FAILED,
- )
- raise e
- return await _collect_results(result_gen)
- @telemetry_event("create_entity")
- async def create_entity(
- self,
- name: str,
- description: str,
- parent_id: UUID,
- category: Optional[str] = None,
- metadata: Optional[dict] = None,
- ) -> Entity:
- description_embedding = str(
- await self.providers.embedding.async_get_embedding(description)
- )
- return await self.providers.database.graphs_handler.entities.create(
- name=name,
- parent_id=parent_id,
- store_type=StoreType.GRAPHS,
- category=category,
- description=description,
- description_embedding=description_embedding,
- metadata=metadata,
- )
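
    # NOTE: embeddings are serialized with str() before storage here and in the
    # other create/update methods below; the graphs handler is assumed to parse
    # that string back into a vector column on insert.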
- @telemetry_event("update_entity")
- async def update_entity(
- self,
- entity_id: UUID,
- name: Optional[str] = None,
- description: Optional[str] = None,
- category: Optional[str] = None,
- metadata: Optional[dict] = None,
- ) -> Entity:
- description_embedding = None
- if description is not None:
- description_embedding = str(
- await self.providers.embedding.async_get_embedding(description)
- )
- return await self.providers.database.graphs_handler.entities.update(
- entity_id=entity_id,
- store_type=StoreType.GRAPHS,
- name=name,
- description=description,
- description_embedding=description_embedding,
- category=category,
- metadata=metadata,
- )
- @telemetry_event("delete_entity")
- async def delete_entity(
- self,
- parent_id: UUID,
- entity_id: UUID,
- ):
- return await self.providers.database.graphs_handler.entities.delete(
- parent_id=parent_id,
- entity_ids=[entity_id],
- store_type=StoreType.GRAPHS,
- )
- @telemetry_event("get_entities")
- async def get_entities(
- self,
- parent_id: UUID,
- offset: int,
- limit: int,
- entity_ids: Optional[list[UUID]] = None,
- entity_names: Optional[list[str]] = None,
- include_embeddings: bool = False,
- ):
- return await self.providers.database.graphs_handler.get_entities(
- parent_id=parent_id,
- offset=offset,
- limit=limit,
- entity_ids=entity_ids,
- entity_names=entity_names,
- include_embeddings=include_embeddings,
- )
- @telemetry_event("create_relationship")
- async def create_relationship(
- self,
- subject: str,
- subject_id: UUID,
- predicate: str,
- object: str,
- object_id: UUID,
- parent_id: UUID,
- description: str | None = None,
- weight: float | None = 1.0,
- metadata: Optional[dict[str, Any] | str] = None,
- ) -> Relationship:
- description_embedding = None
- if description:
- description_embedding = str(
- await self.providers.embedding.async_get_embedding(description)
- )
- return (
- await self.providers.database.graphs_handler.relationships.create(
- subject=subject,
- subject_id=subject_id,
- predicate=predicate,
- object=object,
- object_id=object_id,
- parent_id=parent_id,
- description=description,
- description_embedding=description_embedding,
- weight=weight,
- metadata=metadata,
- store_type=StoreType.GRAPHS,
- )
- )
- @telemetry_event("delete_relationship")
- async def delete_relationship(
- self,
- parent_id: UUID,
- relationship_id: UUID,
- ):
- return (
- await self.providers.database.graphs_handler.relationships.delete(
- parent_id=parent_id,
- relationship_ids=[relationship_id],
- store_type=StoreType.GRAPHS,
- )
- )
- @telemetry_event("update_relationship")
- async def update_relationship(
- self,
- relationship_id: UUID,
- subject: Optional[str] = None,
- subject_id: Optional[UUID] = None,
- predicate: Optional[str] = None,
- object: Optional[str] = None,
- object_id: Optional[UUID] = None,
- description: Optional[str] = None,
- weight: Optional[float] = None,
- metadata: Optional[dict[str, Any] | str] = None,
- ) -> Relationship:
- description_embedding = None
- if description is not None:
- description_embedding = str(
- await self.providers.embedding.async_get_embedding(description)
- )
- return (
- await self.providers.database.graphs_handler.relationships.update(
- relationship_id=relationship_id,
- subject=subject,
- subject_id=subject_id,
- predicate=predicate,
- object=object,
- object_id=object_id,
- description=description,
- description_embedding=description_embedding,
- weight=weight,
- metadata=metadata,
- store_type=StoreType.GRAPHS,
- )
- )
- @telemetry_event("get_relationships")
- async def get_relationships(
- self,
- parent_id: UUID,
- offset: int,
- limit: int,
- relationship_ids: Optional[list[UUID]] = None,
- entity_names: Optional[list[str]] = None,
- ):
- return await self.providers.database.graphs_handler.relationships.get(
- parent_id=parent_id,
- store_type=StoreType.GRAPHS,
- offset=offset,
- limit=limit,
- relationship_ids=relationship_ids,
- entity_names=entity_names,
- )
- @telemetry_event("create_community")
- async def create_community(
- self,
- parent_id: UUID,
- name: str,
- summary: str,
- findings: Optional[list[str]],
- rating: Optional[float],
- rating_explanation: Optional[str],
- ) -> Community:
- description_embedding = str(
- await self.providers.embedding.async_get_embedding(summary)
- )
- return await self.providers.database.graphs_handler.communities.create(
- parent_id=parent_id,
- store_type=StoreType.GRAPHS,
- name=name,
- summary=summary,
- description_embedding=description_embedding,
- findings=findings,
- rating=rating,
- rating_explanation=rating_explanation,
- )
- @telemetry_event("update_community")
- async def update_community(
- self,
- community_id: UUID,
- name: Optional[str],
- summary: Optional[str],
- findings: Optional[list[str]],
- rating: Optional[float],
- rating_explanation: Optional[str],
- ) -> Community:
- summary_embedding = None
- if summary is not None:
- summary_embedding = str(
- await self.providers.embedding.async_get_embedding(summary)
- )
- return await self.providers.database.graphs_handler.communities.update(
- community_id=community_id,
- store_type=StoreType.GRAPHS,
- name=name,
- summary=summary,
- summary_embedding=summary_embedding,
- findings=findings,
- rating=rating,
- rating_explanation=rating_explanation,
- )
- @telemetry_event("delete_community")
- async def delete_community(
- self,
- parent_id: UUID,
- community_id: UUID,
- ) -> None:
- await self.providers.database.graphs_handler.communities.delete(
- parent_id=parent_id,
- community_id=community_id,
- )
- @telemetry_event("list_communities")
- async def list_communities(
- self,
- collection_id: UUID,
- offset: int,
- limit: int,
- ):
- return await self.providers.database.graphs_handler.communities.get(
- parent_id=collection_id,
- store_type=StoreType.GRAPHS,
- offset=offset,
- limit=limit,
- )
- @telemetry_event("get_communities")
- async def get_communities(
- self,
- parent_id: UUID,
- offset: int,
- limit: int,
- community_ids: Optional[list[UUID]] = None,
- community_names: Optional[list[str]] = None,
- include_embeddings: bool = False,
- ):
- return await self.providers.database.graphs_handler.get_communities(
- parent_id=parent_id,
- offset=offset,
- limit=limit,
- community_ids=community_ids,
- include_embeddings=include_embeddings,
- )
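
    # TODO - `community_names` is accepted above but never forwarded to
    # `get_communities`; confirm whether the handler should filter on it.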
- # @telemetry_event("create_new_graph")
- # async def create_new_graph(
- # self,
- # collection_id: UUID,
- # user_id: UUID,
- # name: Optional[str],
- # description: str = "",
- # ) -> GraphResponse:
- # return await self.providers.database.graphs_handler.create(
- # collection_id=collection_id,
- # user_id=user_id,
- # name=name,
- # description=description,
- # graph_id=collection_id,
- # )
- async def list_graphs(
- self,
- offset: int,
- limit: int,
- # user_ids: Optional[list[UUID]] = None,
- graph_ids: Optional[list[UUID]] = None,
- collection_id: Optional[UUID] = None,
- ) -> dict[str, list[GraphResponse] | int]:
- return await self.providers.database.graphs_handler.list_graphs(
- offset=offset,
- limit=limit,
- # filter_user_ids=user_ids,
- filter_graph_ids=graph_ids,
- filter_collection_id=collection_id,
- )
- @telemetry_event("update_graph")
- async def update_graph(
- self,
- collection_id: UUID,
- name: Optional[str] = None,
- description: Optional[str] = None,
- ) -> GraphResponse:
- return await self.providers.database.graphs_handler.update(
- collection_id=collection_id,
- name=name,
- description=description,
- )
- @telemetry_event("reset_graph_v3")
- async def reset_graph_v3(self, id: UUID) -> bool:
- await self.providers.database.graphs_handler.reset(
- parent_id=id,
- )
- await self.providers.database.documents_handler.set_workflow_status(
- id=id,
- status_type="graph_cluster_status",
- status=KGEnrichmentStatus.PENDING,
- )
- return True
- @telemetry_event("get_document_ids_for_create_graph")
- async def get_document_ids_for_create_graph(
- self,
- collection_id: UUID,
- force_kg_creation: bool = False,
- **kwargs,
- ):
- document_status_filter = [
- KGExtractionStatus.PENDING,
- KGExtractionStatus.FAILED,
- ]
- if force_kg_creation:
- document_status_filter += [
- KGExtractionStatus.PROCESSING,
- ]
- return await self.providers.database.documents_handler.get_document_ids_by_status(
- status_type="extraction_status",
- status=[str(ele) for ele in document_status_filter],
- collection_id=collection_id,
- )
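
    # By default only PENDING and FAILED documents are selected for (re)creation;
    # force_kg_creation=True additionally re-queues documents stuck in PROCESSING.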
- @telemetry_event("kg_entity_description")
- async def kg_entity_description(
- self,
- document_id: UUID,
- max_description_input_length: int,
- **kwargs,
- ):
- start_time = time.time()
- logger.info(
- f"KGService: Running kg_entity_description for document {document_id}"
- )
- entity_count = (
- await self.providers.database.graphs_handler.get_entity_count(
- document_id=document_id,
- distinct=True,
- entity_table_name="documents_entities",
- )
- )
- logger.info(
- f"KGService: Found {entity_count} entities in document {document_id}"
- )
- # TODO - Do not hardcode the batch size,
- # make it a configurable parameter at runtime & server-side defaults
- # process 256 entities at a time
- num_batches = math.ceil(entity_count / 256)
- logger.info(
- f"Calling `kg_entity_description` on document {document_id} with an entity count of {entity_count} and total batches of {num_batches}"
- )
- all_results = []
- for i in range(num_batches):
- logger.info(
- f"KGService: Running kg_entity_description for batch {i+1}/{num_batches} for document {document_id}"
- )
- node_descriptions = await self.pipes.graph_description_pipe.run(
- input=self.pipes.graph_description_pipe.Input(
- message={
- "offset": i * 256,
- "limit": 256,
- "max_description_input_length": max_description_input_length,
- "document_id": document_id,
- "logger": logger,
- }
- ),
- state=None,
- run_manager=self.run_manager,
- )
- all_results.append(await _collect_results(node_descriptions))
- logger.info(
- f"KGService: Completed kg_entity_description for batch {i+1}/{num_batches} for document {document_id}"
- )
- await self.providers.database.documents_handler.set_workflow_status(
- id=document_id,
- status_type="extraction_status",
- status=KGExtractionStatus.SUCCESS,
- )
- logger.info(
- f"KGService: Completed kg_entity_description for document {document_id} in {time.time() - start_time:.2f} seconds",
- )
- return all_results
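
    # Worked example of the batching above: with entity_count == 600 and a batch
    # size of 256, num_batches == math.ceil(600 / 256) == 3, so the description
    # pipe runs with offsets 0, 256 and 512.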
- @telemetry_event("kg_clustering")
- async def kg_clustering(
- self,
- collection_id: UUID,
- # graph_id: UUID,
- generation_config: GenerationConfig,
- leiden_params: dict,
- **kwargs,
- ):
- logger.info(
- f"Running ClusteringPipe for collection {collection_id} with settings {leiden_params}"
- )
- clustering_result = await self.pipes.graph_clustering_pipe.run(
- input=self.pipes.graph_clustering_pipe.Input(
- message={
- "collection_id": collection_id,
- "generation_config": generation_config,
- "leiden_params": leiden_params,
- "logger": logger,
- "clustering_mode": self.config.database.graph_creation_settings.clustering_mode,
- }
- ),
- state=None,
- run_manager=self.run_manager,
- )
- return await _collect_results(clustering_result)
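
    # `leiden_params` is forwarded verbatim to the clustering pipe. A minimal,
    # hypothetical example (the accepted keys depend on the clustering backend
    # configured via `clustering_mode`, not on this module):
    #
    #     await service.kg_clustering(
    #         collection_id=collection_id,
    #         generation_config=generation_config,
    #         leiden_params={"max_cluster_size": 1000, "random_seed": 42},
    #     )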
- @telemetry_event("kg_community_summary")
- async def kg_community_summary(
- self,
- offset: int,
- limit: int,
- max_summary_input_length: int,
- generation_config: GenerationConfig,
- collection_id: UUID | None,
- # graph_id: UUID | None,
- **kwargs,
- ):
- summary_results = await self.pipes.graph_community_summary_pipe.run(
- input=self.pipes.graph_community_summary_pipe.Input(
- message={
- "offset": offset,
- "limit": limit,
- "generation_config": generation_config,
- "max_summary_input_length": max_summary_input_length,
- "collection_id": collection_id,
- # "graph_id": graph_id,
- "logger": logger,
- }
- ),
- state=None,
- run_manager=self.run_manager,
- )
- return await _collect_results(summary_results)
- @telemetry_event("delete_graph_for_documents")
- async def delete_graph_for_documents(
- self,
- document_ids: list[UUID],
- **kwargs,
- ):
- # TODO: Implement this, as it needs some checks.
- raise NotImplementedError
- @telemetry_event("delete_graph")
- async def delete_graph(
- self,
- collection_id: UUID,
- ):
- return await self.delete(collection_id=collection_id)
- @telemetry_event("delete")
- async def delete(
- self,
- collection_id: UUID,
- **kwargs,
- ):
- return await self.providers.database.graphs_handler.delete(
- collection_id=collection_id,
- )
- @telemetry_event("get_creation_estimate")
- async def get_creation_estimate(
- self,
- graph_creation_settings: KGCreationSettings,
- document_id: Optional[UUID] = None,
- collection_id: Optional[UUID] = None,
- **kwargs,
- ):
- return (
- await self.providers.database.graphs_handler.get_creation_estimate(
- document_id=document_id,
- collection_id=collection_id,
- graph_creation_settings=graph_creation_settings,
- )
- )
- @telemetry_event("get_enrichment_estimate")
- async def get_enrichment_estimate(
- self,
- collection_id: Optional[UUID] = None,
- graph_id: Optional[UUID] = None,
- graph_enrichment_settings: KGEnrichmentSettings = KGEnrichmentSettings(),
- **kwargs,
- ):
- if graph_id is None and collection_id is None:
- raise ValueError(
- "Either graph_id or collection_id must be provided"
- )
- return await self.providers.database.graphs_handler.get_enrichment_estimate(
- collection_id=collection_id,
- graph_id=graph_id,
- graph_enrichment_settings=graph_enrichment_settings,
- )

    async def kg_extraction(  # type: ignore
        self,
        document_id: UUID,
        generation_config: GenerationConfig,
        max_knowledge_relationships: int,
        entity_types: list[str],
        relation_types: list[str],
        chunk_merge_count: int,
        filter_out_existing_chunks: bool = True,
        total_tasks: Optional[int] = None,
        *args: Any,
        **kwargs: Any,
    ) -> AsyncGenerator[KGExtraction | R2RDocumentProcessingError, None]:
        """Page through a document's chunks, group them, and yield one KG
        extraction (or a processing error) per group."""
        start_time = time.time()

        logger.info(
            f"GraphExtractionPipe: Processing document {document_id} for KG extraction",
        )

        # Fetch all chunks for the document, paging through the handler.
        limit = 100
        offset = 0
        chunks = []
        while True:
            chunk_req = await self.providers.database.chunks_handler.list_document_chunks(  # FIXME: This was using the pagination defaults from before... We need to review if this is as intended.
                document_id=document_id,
                offset=offset,
                limit=limit,
            )

            chunks.extend(
                [
                    DocumentChunk(
                        id=chunk["id"],
                        document_id=chunk["document_id"],
                        owner_id=chunk["owner_id"],
                        collection_ids=chunk["collection_ids"],
                        data=chunk["text"],
                        metadata=chunk["metadata"],
                    )
                    for chunk in chunk_req["results"]
                ]
            )

            if len(chunk_req["results"]) < limit:
                break

            offset += limit

        logger.info(f"Found {len(chunks)} chunks for document {document_id}")
        if len(chunks) == 0:
            logger.info(f"No chunks found for document {document_id}")
            raise R2RException(
                message="No chunks found for document",
                status_code=404,
            )

        if filter_out_existing_chunks:
            existing_chunk_ids = await self.providers.database.graphs_handler.get_existing_document_entity_chunk_ids(
                document_id=document_id
            )
            chunks = [
                chunk for chunk in chunks if chunk.id not in existing_chunk_ids
            ]
            logger.info(
                f"Filtered out {len(existing_chunk_ids)} existing chunks, remaining {len(chunks)} chunks for document {document_id}"
            )

            if len(chunks) == 0:
                logger.info(f"No extractions left for document {document_id}")
                return

        logger.info(
            f"GraphExtractionPipe: Obtained {len(chunks)} chunks to process, time from start: {time.time() - start_time:.2f} seconds",
        )

        # Sort the chunks according to the chunk_order field in metadata,
        # in ascending order.
        chunks = sorted(
            chunks,
            key=lambda x: x.metadata.get("chunk_order", float("inf")),
        )

        # Group the chunks into groups of chunk_merge_count.
        grouped_chunks = [
            chunks[i : i + chunk_merge_count]
            for i in range(0, len(chunks), chunk_merge_count)
        ]

        logger.info(
            f"GraphExtractionPipe: Extracting KG Relationships for document and created {len(grouped_chunks)} tasks, time from start: {time.time() - start_time:.2f} seconds",
        )

        tasks = [
            asyncio.create_task(
                self._extract_kg(
                    chunks=chunk_group,
                    generation_config=generation_config,
                    max_knowledge_relationships=max_knowledge_relationships,
                    entity_types=entity_types,
                    relation_types=relation_types,
                    task_id=task_id,
                    total_tasks=len(grouped_chunks),
                )
            )
            for task_id, chunk_group in enumerate(grouped_chunks)
        ]

        completed_tasks = 0
        total_tasks = len(tasks)

        logger.info(
            f"GraphExtractionPipe: Waiting for {total_tasks} KG extraction tasks to complete",
        )

        for completed_task in asyncio.as_completed(tasks):
            try:
                yield await completed_task
                completed_tasks += 1
                if completed_tasks % 100 == 0:
                    logger.info(
                        f"GraphExtractionPipe: Completed {completed_tasks}/{total_tasks} KG extraction tasks",
                    )
            except Exception as e:
                logger.error(f"Error in Extracting KG Relationships: {e}")
                yield R2RDocumentProcessingError(
                    document_id=document_id,
                    error_message=str(e),
                )

        logger.info(
            f"GraphExtractionPipe: Completed {completed_tasks}/{total_tasks} KG extraction tasks, time from start: {time.time() - start_time:.2f} seconds",
        )
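
    # Grouping example: with 10 chunks and chunk_merge_count == 4, grouped_chunks
    # holds three groups of sizes 4, 4 and 2, so three extraction tasks run
    # concurrently via asyncio.as_completed.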

    async def _extract_kg(
        self,
        chunks: list[DocumentChunk],
        generation_config: GenerationConfig,
        max_knowledge_relationships: int,
        entity_types: list[str],
        relation_types: list[str],
        retries: int = 5,
        delay: int = 2,
        task_id: Optional[int] = None,
        total_tasks: Optional[int] = None,
    ) -> KGExtraction:
        """
        Extracts NER relationships from an extraction with retries.
        """
        # Combine all chunk texts into a single string.
        combined_extraction: str = " ".join([chunk.data for chunk in chunks])  # type: ignore

        response = await self.providers.database.documents_handler.get_documents_overview(  # type: ignore
            offset=0,
            limit=1,
            filter_document_ids=[chunks[0].document_id],
        )
        document_summary = (
            response["results"][0].summary if response["results"] else None
        )

        messages = await self.providers.database.prompts_handler.get_message_payload(
            task_prompt_name=self.providers.database.config.graph_creation_settings.graphrag_relationships_extraction_few_shot,
            task_inputs={
                "document_summary": document_summary,
                "input": combined_extraction,
                "max_knowledge_relationships": max_knowledge_relationships,
                "entity_types": "\n".join(entity_types),
                "relation_types": "\n".join(relation_types),
            },
        )

        for attempt in range(retries):
            try:
                response = await self.providers.llm.aget_completion(
                    messages,
                    generation_config=generation_config,
                )

                kg_extraction = response.choices[0].message.content

                if not kg_extraction:
                    raise R2RException(
                        "No knowledge graph extraction found in the response string, the selected LLM likely failed to format its response correctly.",
                        400,
                    )

                entity_pattern = (
                    r'\("entity"\${4}([^$]+)\${4}([^$]+)\${4}([^$]+)\)'
                )
                relationship_pattern = r'\("relationship"\${4}([^$]+)\${4}([^$]+)\${4}([^$]+)\${4}([^$]+)\${4}(\d+(?:\.\d+)?)\)'
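
                # The patterns above expect tuples delimited by four dollar
                # signs in the LLM output, e.g. (format inferred from the
                # regexes, not from a documented spec):
                #   ("entity"$$$$Aristotle$$$$PERSON$$$$A Greek philosopher...)
                #   ("relationship"$$$$Aristotle$$$$Plato$$$$studied_under$$$$Aristotle studied under Plato$$$$0.9)
                # Relationship fields unpack below as subject, object,
                # predicate, description and weight, in that order.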

                async def parse_fn(response_str: str) -> Any:
                    entities = re.findall(entity_pattern, response_str)

                    if (
                        len(kg_extraction)
                        > MIN_VALID_KG_EXTRACTION_RESPONSE_LENGTH
                        and len(entities) == 0
                    ):
                        raise R2RException(
                            f"No entities found in the response string, the selected LLM likely failed to format its response correctly. {response_str}",
                            400,
                        )

                    relationships = re.findall(
                        relationship_pattern, response_str
                    )

                    entities_arr = []
                    for entity in entities:
                        entity_value = entity[0]
                        entity_category = entity[1]
                        entity_description = entity[2]
                        description_embedding = (
                            await self.providers.embedding.async_get_embedding(
                                entity_description
                            )
                        )
                        entities_arr.append(
                            Entity(
                                category=entity_category,
                                description=entity_description,
                                name=entity_value,
                                parent_id=chunks[0].document_id,
                                chunk_ids=[chunk.id for chunk in chunks],
                                description_embedding=description_embedding,
                                attributes={},
                            )
                        )

                    relations_arr = []
                    for relationship in relationships:
                        subject = relationship[0]
                        object = relationship[1]
                        predicate = relationship[2]
                        description = relationship[3]
                        weight = float(relationship[4])
                        relationship_embedding = (
                            await self.providers.embedding.async_get_embedding(
                                description
                            )
                        )

                        # NOTE: subject/object IDs are resolved later, in
                        # store_kg_extractions, via entities_id_map.
                        relations_arr.append(
                            Relationship(
                                subject=subject,
                                predicate=predicate,
                                object=object,
                                description=description,
                                weight=weight,
                                parent_id=chunks[0].document_id,
                                chunk_ids=[chunk.id for chunk in chunks],
                                attributes={},
                                description_embedding=relationship_embedding,
                            )
                        )

                    return entities_arr, relations_arr

                entities, relationships = await parse_fn(kg_extraction)
                return KGExtraction(
                    entities=entities,
                    relationships=relationships,
                )

            except (
                Exception,
                json.JSONDecodeError,
                KeyError,
                IndexError,
                R2RException,
            ) as e:
                if attempt < retries - 1:
                    await asyncio.sleep(delay)
                else:
                    logger.warning(
                        f"Failed after retries for chunk {chunks[0].id} of document {chunks[0].document_id}: {e}"
                    )

        logger.info(
            f"GraphExtractionPipe: Completed task number {task_id} of {total_tasks} for document {chunks[0].document_id}",
        )

        return KGExtraction(
            entities=[],
            relationships=[],
        )

    async def store_kg_extractions(
        self,
        kg_extractions: list[KGExtraction],
    ):
        """
        Stores a batch of knowledge graph extractions in the graph database.
        """
        for extraction in kg_extractions:
            entities_id_map = {}
            for entity in extraction.entities:
                result = await self.providers.database.graphs_handler.entities.create(
                    name=entity.name,
                    parent_id=entity.parent_id,
                    store_type=StoreType.DOCUMENTS,
                    category=entity.category,
                    description=entity.description,
                    description_embedding=entity.description_embedding,
                    chunk_ids=entity.chunk_ids,
                    metadata=entity.metadata,
                )
                entities_id_map[entity.name] = result.id

            if extraction.relationships:
                for relationship in extraction.relationships:
                    await self.providers.database.graphs_handler.relationships.create(
                        subject=relationship.subject,
                        subject_id=entities_id_map.get(relationship.subject),
                        predicate=relationship.predicate,
                        object=relationship.object,
                        object_id=entities_id_map.get(relationship.object),
                        parent_id=relationship.parent_id,
                        description=relationship.description,
                        description_embedding=relationship.description_embedding,
                        weight=relationship.weight,
                        metadata=relationship.metadata,
                        store_type=StoreType.DOCUMENTS,
                    )
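

# End-to-end sketch (hypothetical driver code; the GenerationConfig arguments
# are illustrative, not prescribed by this module):
#
#     extractions = []
#     async for result in service.kg_extraction(
#         document_id=doc_id,
#         generation_config=GenerationConfig(model="openai/gpt-4o-mini"),
#         max_knowledge_relationships=100,
#         entity_types=["PERSON", "ORGANIZATION"],
#         relation_types=["WORKS_AT", "LOCATED_IN"],
#         chunk_merge_count=4,
#     ):
#         if not isinstance(result, R2RDocumentProcessingError):
#             extractions.append(result)
#     await service.store_kg_extractions(extractions)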
|