123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130 |
- import asyncio
- import logging
- from typing import Any, AsyncGenerator
- from uuid import UUID
- from core.base import AsyncState, KGExtraction, R2RDocumentProcessingError
- from core.base.pipes.base_pipe import AsyncPipe
- from core.database import PostgresDatabaseProvider
- logger = logging.getLogger()
class GraphStoragePipe(AsyncPipe):
    """Pipe that persists knowledge-graph extractions into the graph database.

    Consumes items from an async stream — either ``KGExtraction`` objects or
    ``R2RDocumentProcessingError`` instances produced upstream — stores the
    extractions in batches of ``storage_batch_size``, and re-yields any
    upstream errors once all storage tasks have completed.
    """

    # TODO - Apply correct type hints to storage messages
    class Input(AsyncPipe.Input):
        # NOTE(review): `_run_logic` iterates this stream item-by-item and
        # treats each item as a KGExtraction or R2RDocumentProcessingError,
        # not a list — annotation kept as-is pending the class TODO.
        message: AsyncGenerator[list[Any], None]

    def __init__(
        self,
        database_provider: PostgresDatabaseProvider,
        config: AsyncPipe.PipeConfig,
        storage_batch_size: int = 1,
        *args,
        **kwargs,
    ):
        """
        Initializes the async knowledge graph storage pipe with necessary
        components and configurations.

        Args:
            database_provider: Postgres-backed provider whose
                ``graphs_handler`` performs the actual entity/relationship
                writes.
            config: Standard pipe configuration (``config.name`` is used to
                label storage tasks).
            storage_batch_size: Number of extractions accumulated before a
                storage task is scheduled.
        """
        logger.info(
            "Initializing an `GraphStoragePipe` to store knowledge graph extractions in a graph database."
        )
        super().__init__(
            config,
            *args,
            **kwargs,
        )
        self.database_provider = database_provider
        self.storage_batch_size = storage_batch_size

    async def store(
        self,
        kg_extractions: list[KGExtraction],
    ) -> tuple[int, int]:
        """
        Stores a batch of knowledge graph extractions in the graph database.

        Entities are written one at a time via ``entities.create``;
        relationships are written as a whole list via
        ``relationships.create``.

        Returns:
            A ``(total_entities, total_relationships)`` tuple counting the
            items across the whole batch.
        """
        total_entities, total_relationships = 0, 0
        for extraction in kg_extractions:
            total_entities += len(extraction.entities)
            total_relationships += len(extraction.relationships)

            if extraction.entities:
                # Backfill provenance fields when the extractor left them
                # unset (checked via the first entity, matching the
                # original behavior).
                if not extraction.entities[0].chunk_ids:
                    for entity in extraction.entities:
                        entity.chunk_ids = extraction.chunk_ids
                        entity.parent_id = extraction.document_id
                for entity in extraction.entities:
                    await self.database_provider.graphs_handler.entities.create(
                        **entity.to_dict()
                    )

            if extraction.relationships:
                # Same backfill for relationships; note they set
                # `document_id` where entities set `parent_id` — TODO
                # confirm this asymmetry is intentional.
                if not extraction.relationships[0].chunk_ids:
                    for relationship in extraction.relationships:
                        relationship.chunk_ids = extraction.chunk_ids
                        relationship.document_id = extraction.document_id
                await self.database_provider.graphs_handler.relationships.create(
                    extraction.relationships,
                )

        return (total_entities, total_relationships)

    async def _run_logic(  # type: ignore
        self,
        input: Input,
        state: AsyncState,
        run_id: UUID,
        *args: Any,
        **kwargs: Any,
    ) -> AsyncGenerator[list[R2RDocumentProcessingError], None]:
        """
        Executes the async knowledge graph storage pipe: storing knowledge
        graph extractions in the graph database.

        Upstream ``R2RDocumentProcessingError`` items are collected while
        extractions are batched and stored concurrently; the errors are
        yielded only after every storage task has finished.
        """
        batch_tasks: list[asyncio.Task] = []
        kg_batch: list[KGExtraction] = []
        errors: list[R2RDocumentProcessingError] = []

        def _schedule(batch: list[KGExtraction]) -> None:
            # Schedule a concurrent storage task; copy the batch so the
            # caller may clear and reuse its list.
            batch_tasks.append(
                asyncio.create_task(
                    self.store(batch.copy()),
                    name=f"kg-store-{self.config.name}",
                )
            )

        async for kg_extraction in input.message:
            # Pass upstream processing errors through rather than storing.
            if isinstance(kg_extraction, R2RDocumentProcessingError):
                errors.append(kg_extraction)
                continue

            kg_batch.append(kg_extraction)  # type: ignore
            if len(kg_batch) >= self.storage_batch_size:
                _schedule(kg_batch)
                kg_batch.clear()

        if kg_batch:  # Process any remaining extractions
            _schedule(kg_batch)

        # Wait for all storage tasks to complete
        await asyncio.gather(*batch_tasks)

        for error in errors:
            yield error
|