import asyncio
import json
import logging
import math
import re
import time
from typing import Any, AsyncGenerator, Optional
from uuid import UUID

from core.base import (
    DocumentChunk,
    KGExtraction,
    KGExtractionStatus,
    R2RDocumentProcessingError,
    RunManager,
)
from core.base.abstractions import (
    Community,
    Entity,
    GenerationConfig,
    KGCreationSettings,
    KGEnrichmentSettings,
    KGEnrichmentStatus,
    KGEntityDeduplicationSettings,
    KGEntityDeduplicationType,
    R2RException,
    Relationship,
)
from core.base.api.models import GraphResponse
from core.telemetry.telemetry_decorator import telemetry_event

from ..abstractions import R2RAgents, R2RPipelines, R2RPipes, R2RProviders
from ..config import R2RConfig
from .base import Service

logger = logging.getLogger()

MIN_VALID_KG_EXTRACTION_RESPONSE_LENGTH = 128


async def _collect_results(result_gen: AsyncGenerator) -> list[dict]:
    results = []
    async for res in result_gen:
        results.append(res.json() if hasattr(res, "json") else res)
    return results
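
# _collect_results drains an async generator into a plain list, serializing any
# object that exposes a `.json()` method along the way. A minimal sketch of the
# pattern (hypothetical generator, not part of this module):
#
#   async def gen():
#       yield {"entity": "a"}
#       yield {"entity": "b"}
#
#   results = await _collect_results(gen())  # [{"entity": "a"}, {"entity": "b"}]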


# TODO - Fix naming convention to read `KGService` instead of `KgService`.
# This will require a minor change in how services are registered.
class KgService(Service):
    def __init__(
        self,
        config: R2RConfig,
        providers: R2RProviders,
        pipes: R2RPipes,
        pipelines: R2RPipelines,
        agents: R2RAgents,
        run_manager: RunManager,
    ):
        super().__init__(
            config,
            providers,
            pipes,
            pipelines,
            agents,
            run_manager,
        )
- @telemetry_event("kg_relationships_extraction")
- async def kg_relationships_extraction(
- self,
- document_id: UUID,
- generation_config: GenerationConfig,
- chunk_merge_count: int,
- max_knowledge_relationships: int,
- entity_types: list[str],
- relation_types: list[str],
- **kwargs,
- ):
- try:
- logger.info(
- f"KGService: Processing document {document_id} for KG extraction"
- )
- await self.providers.database.documents_handler.set_workflow_status(
- id=document_id,
- status_type="extraction_status",
- status=KGExtractionStatus.PROCESSING,
- )
- relationships = await self.pipes.kg_relationships_extraction_pipe.run(
- input=self.pipes.kg_relationships_extraction_pipe.Input(
- message={
- "document_id": document_id,
- "generation_config": generation_config,
- "chunk_merge_count": chunk_merge_count,
- "max_knowledge_relationships": max_knowledge_relationships,
- "entity_types": entity_types,
- "relation_types": relation_types,
- "logger": logger,
- }
- ),
- state=None,
- run_manager=self.run_manager,
- )
- logger.info(
- f"KGService: Finished processing document {document_id} for KG extraction"
- )
- result_gen = await self.pipes.kg_storage_pipe.run(
- input=self.pipes.kg_storage_pipe.Input(message=relationships),
- state=None,
- run_manager=self.run_manager,
- )
- except Exception as e:
- logger.error(f"KGService: Error in kg_extraction: {e}")
- await self.providers.database.documents_handler.set_workflow_status(
- id=document_id,
- status_type="extraction_status",
- status=KGExtractionStatus.FAILED,
- )
- raise e
- return await _collect_results(result_gen)
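
    # Example (hypothetical usage sketch; `service` is an initialized KgService
    # and all argument values are illustrative placeholders):
    #
    #   results = await service.kg_relationships_extraction(
    #       document_id=document_id,
    #       generation_config=GenerationConfig(model="openai/gpt-4o-mini"),
    #       chunk_merge_count=4,
    #       max_knowledge_relationships=100,
    #       entity_types=["Person", "Organization"],
    #       relation_types=["works_at", "founded"],
    #   )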
- @telemetry_event("create_entity")
- async def create_entity(
- self,
- name: str,
- description: str,
- parent_id: UUID,
- category: Optional[str] = None,
- metadata: Optional[dict] = None,
- ) -> Entity:
- description_embedding = str(
- await self.providers.embedding.async_get_embedding(description)
- )
- return await self.providers.database.graphs_handler.entities.create(
- name=name,
- parent_id=parent_id,
- store_type="graphs", # type: ignore
- category=category,
- description=description,
- description_embedding=description_embedding,
- metadata=metadata,
- )
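
    # Example (hypothetical usage sketch; the UUID is a placeholder):
    #
    #   entity = await service.create_entity(
    #       name="Aristotle",
    #       description="Greek philosopher and polymath",
    #       parent_id=graph_id,
    #       category="Person",
    #   )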
- @telemetry_event("update_entity")
- async def update_entity(
- self,
- entity_id: UUID,
- name: Optional[str] = None,
- description: Optional[str] = None,
- category: Optional[str] = None,
- metadata: Optional[dict] = None,
- ) -> Entity:
- description_embedding = None
- if description is not None:
- description_embedding = str(
- await self.providers.embedding.async_get_embedding(description)
- )
- return await self.providers.database.graphs_handler.entities.update(
- entity_id=entity_id,
- store_type="graphs", # type: ignore
- name=name,
- description=description,
- description_embedding=description_embedding,
- category=category,
- metadata=metadata,
- )
- @telemetry_event("delete_entity")
- async def delete_entity(
- self,
- parent_id: UUID,
- entity_id: UUID,
- ):
- return await self.providers.database.graphs_handler.entities.delete(
- parent_id=parent_id,
- entity_ids=[entity_id],
- store_type="graphs", # type: ignore
- )
- @telemetry_event("get_entities")
- async def get_entities(
- self,
- parent_id: UUID,
- offset: int,
- limit: int,
- entity_ids: Optional[list[UUID]] = None,
- entity_names: Optional[list[str]] = None,
- include_embeddings: bool = False,
- ):
- return await self.providers.database.graphs_handler.get_entities(
- parent_id=parent_id,
- offset=offset,
- limit=limit,
- entity_ids=entity_ids,
- entity_names=entity_names,
- include_embeddings=include_embeddings,
- )
- @telemetry_event("create_relationship")
- async def create_relationship(
- self,
- subject: str,
- subject_id: UUID,
- predicate: str,
- object: str,
- object_id: UUID,
- parent_id: UUID,
- description: str | None = None,
- weight: float | None = 1.0,
- metadata: Optional[dict[str, Any] | str] = None,
- ) -> Relationship:
- description_embedding = None
- if description:
- description_embedding = str(
- await self.providers.embedding.async_get_embedding(description)
- )
- return (
- await self.providers.database.graphs_handler.relationships.create(
- subject=subject,
- subject_id=subject_id,
- predicate=predicate,
- object=object,
- object_id=object_id,
- parent_id=parent_id,
- description=description,
- description_embedding=description_embedding,
- weight=weight,
- metadata=metadata,
- store_type="graphs", # type: ignore
- )
- )
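
    # Example (hypothetical usage sketch; `subject_id`/`object_id` would be ids
    # returned by create_entity, and the other values are placeholders):
    #
    #   relationship = await service.create_relationship(
    #       subject="Aristotle",
    #       subject_id=aristotle_id,
    #       predicate="studied under",
    #       object="Plato",
    #       object_id=plato_id,
    #       parent_id=graph_id,
    #       description="Aristotle studied under Plato at the Academy",
    #       weight=0.9,
    #   )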
- @telemetry_event("delete_relationship")
- async def delete_relationship(
- self,
- parent_id: UUID,
- relationship_id: UUID,
- ):
- return (
- await self.providers.database.graphs_handler.relationships.delete(
- parent_id=parent_id,
- relationship_ids=[relationship_id],
- store_type="graphs", # type: ignore
- )
- )
- @telemetry_event("update_relationship")
- async def update_relationship(
- self,
- relationship_id: UUID,
- subject: Optional[str] = None,
- subject_id: Optional[UUID] = None,
- predicate: Optional[str] = None,
- object: Optional[str] = None,
- object_id: Optional[UUID] = None,
- description: Optional[str] = None,
- weight: Optional[float] = None,
- metadata: Optional[dict[str, Any] | str] = None,
- ) -> Relationship:
- description_embedding = None
- if description is not None:
- description_embedding = str(
- await self.providers.embedding.async_get_embedding(description)
- )
- return (
- await self.providers.database.graphs_handler.relationships.update(
- relationship_id=relationship_id,
- subject=subject,
- subject_id=subject_id,
- predicate=predicate,
- object=object,
- object_id=object_id,
- description=description,
- description_embedding=description_embedding,
- weight=weight,
- metadata=metadata,
- store_type="graphs", # type: ignore
- )
- )
- @telemetry_event("get_relationships")
- async def get_relationships(
- self,
- parent_id: UUID,
- offset: int,
- limit: int,
- relationship_ids: Optional[list[UUID]] = None,
- entity_names: Optional[list[str]] = None,
- ):
- return await self.providers.database.graphs_handler.relationships.get(
- parent_id=parent_id,
- store_type="graphs", # type: ignore
- offset=offset,
- limit=limit,
- relationship_ids=relationship_ids,
- entity_names=entity_names,
- )
- @telemetry_event("create_community")
- async def create_community(
- self,
- parent_id: UUID,
- name: str,
- summary: str,
- findings: Optional[list[str]],
- rating: Optional[float],
- rating_explanation: Optional[str],
- ) -> Community:
- description_embedding = str(
- await self.providers.embedding.async_get_embedding(summary)
- )
- return await self.providers.database.graphs_handler.communities.create(
- parent_id=parent_id,
- store_type="graphs", # type: ignore
- name=name,
- summary=summary,
- description_embedding=description_embedding,
- findings=findings,
- rating=rating,
- rating_explanation=rating_explanation,
- )
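
    # Example (hypothetical usage sketch; values are placeholders):
    #
    #   community = await service.create_community(
    #       parent_id=collection_id,
    #       name="Ancient Greek philosophers",
    #       summary="Philosophers active in classical Athens",
    #       findings=["Strong teacher-student lineage"],
    #       rating=8.0,
    #       rating_explanation="Well supported by extracted relationships",
    #   )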
- @telemetry_event("update_community")
- async def update_community(
- self,
- community_id: UUID,
- name: Optional[str],
- summary: Optional[str],
- findings: Optional[list[str]],
- rating: Optional[float],
- rating_explanation: Optional[str],
- ) -> Community:
- summary_embedding = None
- if summary is not None:
- summary_embedding = str(
- await self.providers.embedding.async_get_embedding(summary)
- )
- return await self.providers.database.graphs_handler.communities.update(
- community_id=community_id,
- store_type="graphs", # type: ignore
- name=name,
- summary=summary,
- summary_embedding=summary_embedding,
- findings=findings,
- rating=rating,
- rating_explanation=rating_explanation,
- )
- @telemetry_event("delete_community")
- async def delete_community(
- self,
- parent_id: UUID,
- community_id: UUID,
- ) -> None:
- await self.providers.database.graphs_handler.communities.delete(
- parent_id=parent_id,
- community_id=community_id,
- )
- @telemetry_event("list_communities")
- async def list_communities(
- self,
- collection_id: UUID,
- offset: int,
- limit: int,
- ):
- return await self.providers.database.graphs_handler.communities.get(
- parent_id=collection_id,
- store_type="graphs", # type: ignore
- offset=offset,
- limit=limit,
- )
- @telemetry_event("get_communities")
- async def get_communities(
- self,
- parent_id: UUID,
- offset: int,
- limit: int,
- community_ids: Optional[list[UUID]] = None,
- community_names: Optional[list[str]] = None,
- include_embeddings: bool = False,
- ):
- return await self.providers.database.graphs_handler.get_communities(
- parent_id=parent_id,
- offset=offset,
- limit=limit,
- community_ids=community_ids,
- include_embeddings=include_embeddings,
- )
- # @telemetry_event("create_new_graph")
- # async def create_new_graph(
- # self,
- # collection_id: UUID,
- # user_id: UUID,
- # name: Optional[str],
- # description: str = "",
- # ) -> GraphResponse:
- # return await self.providers.database.graphs_handler.create(
- # collection_id=collection_id,
- # user_id=user_id,
- # name=name,
- # description=description,
- # graph_id=collection_id,
- # )
- async def list_graphs(
- self,
- offset: int,
- limit: int,
- # user_ids: Optional[list[UUID]] = None,
- graph_ids: Optional[list[UUID]] = None,
- collection_id: Optional[UUID] = None,
- ) -> dict[str, list[GraphResponse] | int]:
- return await self.providers.database.graphs_handler.list_graphs(
- offset=offset,
- limit=limit,
- # filter_user_ids=user_ids,
- filter_graph_ids=graph_ids,
- filter_collection_id=collection_id,
- )
- @telemetry_event("update_graph")
- async def update_graph(
- self,
- collection_id: UUID,
- name: Optional[str] = None,
- description: Optional[str] = None,
- ) -> GraphResponse:
- return await self.providers.database.graphs_handler.update(
- collection_id=collection_id,
- name=name,
- description=description,
- )
- @telemetry_event("reset_graph_v3")
- async def reset_graph_v3(self, id: UUID) -> bool:
- await self.providers.database.graphs_handler.reset(
- parent_id=id,
- )
- await self.providers.database.documents_handler.set_workflow_status(
- id=id,
- status_type="graph_cluster_status",
- status=KGEnrichmentStatus.PENDING,
- )
- return True
- @telemetry_event("get_document_ids_for_create_graph")
- async def get_document_ids_for_create_graph(
- self,
- collection_id: UUID,
- force_kg_creation: bool = False,
- **kwargs,
- ):
- document_status_filter = [
- KGExtractionStatus.PENDING,
- KGExtractionStatus.FAILED,
- ]
- if force_kg_creation:
- document_status_filter += [
- KGExtractionStatus.PROCESSING,
- ]
- return await self.providers.database.documents_handler.get_document_ids_by_status(
- status_type="extraction_status",
- status=[str(ele) for ele in document_status_filter],
- collection_id=collection_id,
- )
- @telemetry_event("kg_entity_description")
- async def kg_entity_description(
- self,
- document_id: UUID,
- max_description_input_length: int,
- **kwargs,
- ):
- start_time = time.time()
- logger.info(
- f"KGService: Running kg_entity_description for document {document_id}"
- )
- entity_count = (
- await self.providers.database.graphs_handler.get_entity_count(
- document_id=document_id,
- distinct=True,
- entity_table_name="documents_entities",
- )
- )
- logger.info(
- f"KGService: Found {entity_count} entities in document {document_id}"
- )
- # TODO - Do not hardcode the batch size,
- # make it a configurable parameter at runtime & server-side defaults
- # process 256 entities at a time
- num_batches = math.ceil(entity_count / 256)
- logger.info(
- f"Calling `kg_entity_description` on document {document_id} with an entity count of {entity_count} and total batches of {num_batches}"
- )
- all_results = []
- for i in range(num_batches):
- logger.info(
- f"KGService: Running kg_entity_description for batch {i+1}/{num_batches} for document {document_id}"
- )
- node_descriptions = await self.pipes.kg_entity_description_pipe.run(
- input=self.pipes.kg_entity_description_pipe.Input(
- message={
- "offset": i * 256,
- "limit": 256,
- "max_description_input_length": max_description_input_length,
- "document_id": document_id,
- "logger": logger,
- }
- ),
- state=None,
- run_manager=self.run_manager,
- )
- all_results.append(await _collect_results(node_descriptions))
- logger.info(
- f"KGService: Completed kg_entity_description for batch {i+1}/{num_batches} for document {document_id}"
- )
- await self.providers.database.documents_handler.set_workflow_status(
- id=document_id,
- status_type="extraction_status",
- status=KGExtractionStatus.SUCCESS,
- )
- logger.info(
- f"KGService: Completed kg_entity_description for document {document_id} in {time.time() - start_time:.2f} seconds",
- )
- return all_results
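
    # Worked example of the batching in kg_entity_description above: with
    # entity_count = 600 and a batch size of 256, math.ceil(600 / 256) = 3
    # batches, covering offsets 0, 256, and 512 (the last batch holds the
    # remaining 88 entities).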
- @telemetry_event("kg_clustering")
- async def kg_clustering(
- self,
- collection_id: UUID,
- # graph_id: UUID,
- generation_config: GenerationConfig,
- leiden_params: dict,
- **kwargs,
- ):
- logger.info(
- f"Running ClusteringPipe for collection {collection_id} with settings {leiden_params}"
- )
- clustering_result = await self.pipes.kg_clustering_pipe.run(
- input=self.pipes.kg_clustering_pipe.Input(
- message={
- "collection_id": collection_id,
- "generation_config": generation_config,
- "leiden_params": leiden_params,
- "logger": logger,
- "clustering_mode": self.config.database.graph_creation_settings.clustering_mode,
- }
- ),
- state=None,
- run_manager=self.run_manager,
- )
- return await _collect_results(clustering_result)
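
    # Example (hypothetical usage sketch). The leiden_params keys shown are
    # common Leiden clustering knobs and are illustrative assumptions; the keys
    # actually honored depend on the configured clustering backend:
    #
    #   await service.kg_clustering(
    #       collection_id=collection_id,
    #       generation_config=GenerationConfig(model="openai/gpt-4o-mini"),
    #       leiden_params={"max_cluster_size": 1000, "random_seed": 42},
    #   )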
- @telemetry_event("kg_community_summary")
- async def kg_community_summary(
- self,
- offset: int,
- limit: int,
- max_summary_input_length: int,
- generation_config: GenerationConfig,
- collection_id: UUID | None,
- # graph_id: UUID | None,
- **kwargs,
- ):
- summary_results = await self.pipes.kg_community_summary_pipe.run(
- input=self.pipes.kg_community_summary_pipe.Input(
- message={
- "offset": offset,
- "limit": limit,
- "generation_config": generation_config,
- "max_summary_input_length": max_summary_input_length,
- "collection_id": collection_id,
- # "graph_id": graph_id,
- "logger": logger,
- }
- ),
- state=None,
- run_manager=self.run_manager,
- )
- return await _collect_results(summary_results)
- @telemetry_event("delete_graph_for_documents")
- async def delete_graph_for_documents(
- self,
- document_ids: list[UUID],
- **kwargs,
- ):
- # TODO: Implement this, as it needs some checks.
- raise NotImplementedError
- @telemetry_event("delete_graph")
- async def delete_graph(
- self,
- collection_id: UUID,
- cascade: bool,
- **kwargs,
- ):
- return await self.delete_graph_for_collection(
- collection_id=collection_id, cascade=cascade
- )
- @telemetry_event("delete_graph_for_collection")
- async def delete_graph_for_collection(
- self,
- collection_id: UUID,
- cascade: bool,
- **kwargs,
- ):
- return await self.providers.database.graphs_handler.delete_graph_for_collection(
- collection_id=collection_id,
- cascade=cascade,
- )
- @telemetry_event("delete_node_via_document_id")
- async def delete_node_via_document_id(
- self,
- document_id: UUID,
- collection_id: UUID,
- **kwargs,
- ):
- return await self.providers.database.graphs_handler.delete_node_via_document_id(
- document_id=document_id,
- collection_id=collection_id,
- )
- @telemetry_event("get_creation_estimate")
- async def get_creation_estimate(
- self,
- graph_creation_settings: KGCreationSettings,
- document_id: Optional[UUID] = None,
- collection_id: Optional[UUID] = None,
- **kwargs,
- ):
- return (
- await self.providers.database.graphs_handler.get_creation_estimate(
- document_id=document_id,
- collection_id=collection_id,
- graph_creation_settings=graph_creation_settings,
- )
- )
- @telemetry_event("get_enrichment_estimate")
- async def get_enrichment_estimate(
- self,
- collection_id: Optional[UUID] = None,
- graph_id: Optional[UUID] = None,
- graph_enrichment_settings: KGEnrichmentSettings = KGEnrichmentSettings(),
- **kwargs,
- ):
- if graph_id is None and collection_id is None:
- raise ValueError(
- "Either graph_id or collection_id must be provided"
- )
- return await self.providers.database.graphs_handler.get_enrichment_estimate(
- collection_id=collection_id,
- graph_id=graph_id,
- graph_enrichment_settings=graph_enrichment_settings,
- )
- @telemetry_event("get_deduplication_estimate")
- async def get_deduplication_estimate(
- self,
- collection_id: UUID,
- kg_deduplication_settings: KGEntityDeduplicationSettings,
- **kwargs,
- ):
- return await self.providers.database.graphs_handler.get_deduplication_estimate(
- collection_id=collection_id,
- kg_deduplication_settings=kg_deduplication_settings,
- )
- @telemetry_event("kg_entity_deduplication")
- async def kg_entity_deduplication(
- self,
- collection_id: UUID,
- graph_id: UUID,
- graph_entity_deduplication_type: KGEntityDeduplicationType,
- graph_entity_deduplication_prompt: str,
- generation_config: GenerationConfig,
- **kwargs,
- ):
- deduplication_results = await self.pipes.kg_entity_deduplication_pipe.run(
- input=self.pipes.kg_entity_deduplication_pipe.Input(
- message={
- "collection_id": collection_id,
- "graph_id": graph_id,
- "graph_entity_deduplication_type": graph_entity_deduplication_type,
- "graph_entity_deduplication_prompt": graph_entity_deduplication_prompt,
- "generation_config": generation_config,
- **kwargs,
- }
- ),
- state=None,
- run_manager=self.run_manager,
- )
- return await _collect_results(deduplication_results)
- @telemetry_event("kg_entity_deduplication_summary")
- async def kg_entity_deduplication_summary(
- self,
- collection_id: UUID,
- offset: int,
- limit: int,
- graph_entity_deduplication_type: KGEntityDeduplicationType,
- graph_entity_deduplication_prompt: str,
- generation_config: GenerationConfig,
- **kwargs,
- ):
- logger.info(
- f"Running kg_entity_deduplication_summary for collection {collection_id} with settings {kwargs}"
- )
- deduplication_summary_results = await self.pipes.kg_entity_deduplication_summary_pipe.run(
- input=self.pipes.kg_entity_deduplication_summary_pipe.Input(
- message={
- "collection_id": collection_id,
- "offset": offset,
- "limit": limit,
- "graph_entity_deduplication_type": graph_entity_deduplication_type,
- "graph_entity_deduplication_prompt": graph_entity_deduplication_prompt,
- "generation_config": generation_config,
- }
- ),
- state=None,
- run_manager=self.run_manager,
- )
- return await _collect_results(deduplication_summary_results)

    async def kg_extraction(  # type: ignore
        self,
        document_id: UUID,
        generation_config: GenerationConfig,
        max_knowledge_relationships: int,
        entity_types: list[str],
        relation_types: list[str],
        chunk_merge_count: int,
        filter_out_existing_chunks: bool = True,
        total_tasks: Optional[int] = None,
        *args: Any,
        **kwargs: Any,
    ) -> AsyncGenerator[KGExtraction | R2RDocumentProcessingError, None]:
        start_time = time.time()

        logger.info(
            f"KGExtractionPipe: Processing document {document_id} for KG extraction",
        )

        # Then create the extractions from the results
        limit = 100
        offset = 0
        chunks = []
        while True:
            chunk_req = await self.providers.database.chunks_handler.list_document_chunks(  # FIXME: This was using the pagination defaults from before... We need to review if this is as intended.
                document_id=document_id,
                offset=offset,
                limit=limit,
            )

            chunks.extend(
                [
                    DocumentChunk(
                        id=chunk["id"],
                        document_id=chunk["document_id"],
                        owner_id=chunk["owner_id"],
                        collection_ids=chunk["collection_ids"],
                        data=chunk["text"],
                        metadata=chunk["metadata"],
                    )
                    for chunk in chunk_req["results"]
                ]
            )

            if len(chunk_req["results"]) < limit:
                break

            offset += limit

        logger.info(f"Found {len(chunks)} chunks for document {document_id}")
        if len(chunks) == 0:
            logger.info(f"No chunks found for document {document_id}")
            raise R2RException(
                message="No chunks found for document",
                status_code=404,
            )

        if filter_out_existing_chunks:
            existing_chunk_ids = await self.providers.database.graphs_handler.get_existing_document_entity_chunk_ids(
                document_id=document_id
            )
            chunks = [
                chunk for chunk in chunks if chunk.id not in existing_chunk_ids
            ]
            logger.info(
                f"Filtered out {len(existing_chunk_ids)} existing chunks, remaining {len(chunks)} chunks for document {document_id}"
            )

            if len(chunks) == 0:
                logger.info(f"No extractions left for document {document_id}")
                return

        logger.info(
            f"KGExtractionPipe: Obtained {len(chunks)} chunks to process, time from start: {time.time() - start_time:.2f} seconds",
        )

        # sort the extractions according to the chunk_order field in metadata in ascending order
        chunks = sorted(
            chunks,
            key=lambda x: x.metadata.get("chunk_order", float("inf")),
        )

        # group these extractions into groups of chunk_merge_count
        grouped_chunks = [
            chunks[i : i + chunk_merge_count]
            for i in range(0, len(chunks), chunk_merge_count)
        ]

        logger.info(
            f"KGExtractionPipe: Extracting KG Relationships for document and created {len(grouped_chunks)} tasks, time from start: {time.time() - start_time:.2f} seconds",
        )

        tasks = [
            asyncio.create_task(
                self._extract_kg(
                    chunks=chunk_group,
                    generation_config=generation_config,
                    max_knowledge_relationships=max_knowledge_relationships,
                    entity_types=entity_types,
                    relation_types=relation_types,
                    task_id=task_id,
                    total_tasks=len(grouped_chunks),
                )
            )
            for task_id, chunk_group in enumerate(grouped_chunks)
        ]

        completed_tasks = 0
        total_tasks = len(tasks)

        logger.info(
            f"KGExtractionPipe: Waiting for {total_tasks} KG extraction tasks to complete",
        )

        for completed_task in asyncio.as_completed(tasks):
            try:
                yield await completed_task
                completed_tasks += 1
                if completed_tasks % 100 == 0:
                    logger.info(
                        f"KGExtractionPipe: Completed {completed_tasks}/{total_tasks} KG extraction tasks",
                    )
            except Exception as e:
                logger.error(f"Error in Extracting KG Relationships: {e}")
                yield R2RDocumentProcessingError(
                    document_id=document_id,
                    error_message=str(e),
                )

        logger.info(
            f"KGExtractionPipe: Completed {completed_tasks}/{total_tasks} KG extraction tasks, time from start: {time.time() - start_time:.2f} seconds",
        )
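
    # Example (hypothetical usage sketch): kg_extraction is an async generator,
    # so callers consume it with `async for`, handling per-task failures inline:
    #
    #   async for result in service.kg_extraction(
    #       document_id=document_id,
    #       generation_config=GenerationConfig(model="openai/gpt-4o-mini"),
    #       max_knowledge_relationships=100,
    #       entity_types=["Person"],
    #       relation_types=["knows"],
    #       chunk_merge_count=4,
    #   ):
    #       if isinstance(result, R2RDocumentProcessingError):
    #           logger.error(result.error_message)
    #       else:
    #           await service.store_kg_extractions([result])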

    async def _extract_kg(
        self,
        chunks: list[DocumentChunk],
        generation_config: GenerationConfig,
        max_knowledge_relationships: int,
        entity_types: list[str],
        relation_types: list[str],
        retries: int = 5,
        delay: int = 2,
        task_id: Optional[int] = None,
        total_tasks: Optional[int] = None,
    ) -> KGExtraction:
        """
        Extracts NER relationships from an extraction with retries.
        """
        # combine all extractions into a single string
        combined_extraction: str = " ".join([chunk.data for chunk in chunks])  # type: ignore

        response = await self.providers.database.documents_handler.get_documents_overview(  # type: ignore
            offset=0,
            limit=1,
            filter_document_ids=[chunks[0].document_id],
        )
        document_summary = (
            response["results"][0].summary if response["results"] else None
        )

        messages = await self.providers.database.prompts_handler.get_message_payload(
            task_prompt_name=self.providers.database.config.graph_creation_settings.graphrag_relationships_extraction_few_shot,
            task_inputs={
                "document_summary": document_summary,
                "input": combined_extraction,
                "max_knowledge_relationships": max_knowledge_relationships,
                "entity_types": "\n".join(entity_types),
                "relation_types": "\n".join(relation_types),
            },
        )

        for attempt in range(retries):
            try:
                response = await self.providers.llm.aget_completion(
                    messages,
                    generation_config=generation_config,
                )

                kg_extraction = response.choices[0].message.content

                if not kg_extraction:
                    raise R2RException(
                        "No knowledge graph extraction found in the response string, the selected LLM likely failed to format its response correctly.",
                        400,
                    )

                entity_pattern = (
                    r'\("entity"\${4}([^$]+)\${4}([^$]+)\${4}([^$]+)\)'
                )
                relationship_pattern = r'\("relationship"\${4}([^$]+)\${4}([^$]+)\${4}([^$]+)\${4}([^$]+)\${4}(\d+(?:\.\d+)?)\)'
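
                # The patterns above expect the LLM to emit tuples delimited by
                # four dollar signs. Illustrative lines (hypothetical output
                # matching the regexes, not taken from a real completion):
                #
                #   ("entity"$$$$Aristotle$$$$Person$$$$Greek philosopher)
                #   ("relationship"$$$$Aristotle$$$$Plato$$$$studied under$$$$Aristotle was Plato's student$$$$0.9)
                #
                # For relationships the captured groups are, in order: subject,
                # object, predicate, description, and a numeric weight.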

                async def parse_fn(response_str: str) -> Any:
                    entities = re.findall(entity_pattern, response_str)

                    if (
                        len(kg_extraction)
                        > MIN_VALID_KG_EXTRACTION_RESPONSE_LENGTH
                        and len(entities) == 0
                    ):
                        raise R2RException(
                            f"No entities found in the response string, the selected LLM likely failed to format its response correctly. {response_str}",
                            400,
                        )

                    relationships = re.findall(
                        relationship_pattern, response_str
                    )

                    entities_arr = []
                    for entity in entities:
                        entity_value = entity[0]
                        entity_category = entity[1]
                        entity_description = entity[2]
                        description_embedding = (
                            await self.providers.embedding.async_get_embedding(
                                entity_description
                            )
                        )
                        entities_arr.append(
                            Entity(
                                category=entity_category,
                                description=entity_description,
                                name=entity_value,
                                parent_id=chunks[0].document_id,
                                chunk_ids=[chunk.id for chunk in chunks],
                                description_embedding=description_embedding,
                                attributes={},
                            )
                        )

                    relations_arr = []
                    for relationship in relationships:
                        subject = relationship[0]
                        object = relationship[1]
                        predicate = relationship[2]
                        description = relationship[3]
                        weight = float(relationship[4])
                        relationship_embedding = (
                            await self.providers.embedding.async_get_embedding(
                                description
                            )
                        )

                        # check if subject and object are in entities_dict
                        relations_arr.append(
                            Relationship(
                                subject=subject,
                                predicate=predicate,
                                object=object,
                                description=description,
                                weight=weight,
                                parent_id=chunks[0].document_id,
                                chunk_ids=[chunk.id for chunk in chunks],
                                attributes={},
                                description_embedding=relationship_embedding,
                            )
                        )

                    return entities_arr, relations_arr

                entities, relationships = await parse_fn(kg_extraction)
                return KGExtraction(
                    entities=entities,
                    relationships=relationships,
                )

            except (
                Exception,
                json.JSONDecodeError,
                KeyError,
                IndexError,
                R2RException,
            ) as e:
                if attempt < retries - 1:
                    await asyncio.sleep(delay)
                else:
                    logger.warning(
                        f"Failed after retries for chunk {chunks[0].id} of document {chunks[0].document_id}: {e}"
                    )

        logger.info(
            f"KGExtractionPipe: Completed task number {task_id} of {total_tasks} for document {chunks[0].document_id}",
        )

        return KGExtraction(
            entities=[],
            relationships=[],
        )

    async def store_kg_extractions(
        self,
        kg_extractions: list[KGExtraction],
    ):
        """
        Stores a batch of knowledge graph extractions in the graph database.
        """
        for extraction in kg_extractions:
            entities_id_map = {}
            for entity in extraction.entities:
                result = await self.providers.database.graphs_handler.entities.create(
                    name=entity.name,
                    parent_id=entity.parent_id,
                    store_type="documents",  # type: ignore
                    category=entity.category,
                    description=entity.description,
                    description_embedding=entity.description_embedding,
                    chunk_ids=entity.chunk_ids,
                    metadata=entity.metadata,
                )
                entities_id_map[entity.name] = result.id

            if extraction.relationships:
                for relationship in extraction.relationships:
                    await self.providers.database.graphs_handler.relationships.create(
                        subject=relationship.subject,
                        subject_id=entities_id_map.get(relationship.subject),
                        predicate=relationship.predicate,
                        object=relationship.object,
                        object_id=entities_id_map.get(relationship.object),
                        parent_id=relationship.parent_id,
                        description=relationship.description,
                        description_embedding=relationship.description_embedding,
                        weight=relationship.weight,
                        metadata=relationship.metadata,
                        store_type="documents",  # type: ignore
                    )
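
    # A minimal sketch of the name-to-id resolution used above (illustrative
    # values; dict.get returns None for relationship endpoints that were never
    # created as entities in this extraction):
    #
    #   entities_id_map = {"Aristotle": aristotle_id, "Plato": plato_id}
    #   entities_id_map.get("Aristotle")  # -> aristotle_id
    #   entities_id_map.get("Socrates")  # -> None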