import asyncio
import logging
from typing import Any, AsyncGenerator
from uuid import UUID

from core.base import AsyncState, KGExtraction, R2RDocumentProcessingError
from core.base.pipes.base_pipe import AsyncPipe
from core.database import PostgresDatabaseProvider

logger = logging.getLogger()


class GraphStoragePipe(AsyncPipe):
    """Pipe that persists knowledge graph extractions via the database provider.

    Extractions are consumed from an async generator, grouped into batches of
    ``storage_batch_size`` and written through
    ``database_provider.graphs_handler``.
    """

    # TODO - Apply correct type hints to storage messages
    class Input(AsyncPipe.Input):
        message: AsyncGenerator[list[Any], None]

    def __init__(
        self,
        database_provider: PostgresDatabaseProvider,
        config: AsyncPipe.PipeConfig,
        storage_batch_size: int = 1,
        *args,
        **kwargs,
    ):
        """
        Initializes the async knowledge graph storage pipe with necessary components and configurations.

        Args:
            database_provider: Provider whose ``graphs_handler`` is used to
                create entities and relationships.
            config: Pipe configuration forwarded to ``AsyncPipe``.
            storage_batch_size: Number of extractions accumulated before a
                storage task is scheduled.
            *args: Extra positional arguments forwarded to ``AsyncPipe``.
            **kwargs: Extra keyword arguments forwarded to ``AsyncPipe``.
        """
        logger.info(
            "Initializing an `GraphStoragePipe` to store knowledge graph extractions in a graph database."
        )

        super().__init__(
            config,
            *args,
            **kwargs,
        )
        self.database_provider = database_provider
        self.storage_batch_size = storage_batch_size

    async def store(
        self,
        kg_extractions: list[KGExtraction],
    ) -> tuple[int, int]:
        """
        Stores a batch of knowledge graph extractions in the graph database.

        Args:
            kg_extractions: Extractions whose entities and relationships
                should be written.

        Returns:
            ``(total_entities, total_relationships)`` counted over the whole
            batch; ``(0, 0)`` for an empty batch.
        """
        total_entities, total_relationships = 0, 0
        graphs_handler = self.database_provider.graphs_handler

        for extraction in kg_extractions:
            total_entities += len(extraction.entities)
            total_relationships += len(extraction.relationships)

            if extraction.entities:
                # If the first entity carries no chunk ids, stamp every
                # entity with the extraction-level chunk ids and document id.
                if not extraction.entities[0].chunk_ids:
                    for entity in extraction.entities:
                        entity.chunk_ids = extraction.chunk_ids
                        entity.parent_id = extraction.document_id

                for entity in extraction.entities:
                    await graphs_handler.entities.create(**entity.to_dict())

            if extraction.relationships:
                if not extraction.relationships[0].chunk_ids:
                    # Every relationship gets the document id, not only the
                    # last one (the assignment used to sit outside the loop).
                    for relationship in extraction.relationships:
                        relationship.chunk_ids = extraction.chunk_ids
                        relationship.document_id = extraction.document_id

                await graphs_handler.relationships.create(
                    extraction.relationships,
                )

        # Return only after the whole batch is processed; returning inside
        # the loop dropped every extraction after the first one.
        return (total_entities, total_relationships)

    def _schedule_store(self, kg_batch: list[KGExtraction]) -> asyncio.Task:
        """Create a named task that stores a snapshot copy of ``kg_batch``."""
        return asyncio.create_task(
            self.store(kg_batch.copy()),
            name=f"kg-store-{self.config.name}",
        )

    async def _run_logic(  # type: ignore
        self,
        input: Input,
        state: AsyncState,
        run_id: UUID,
        *args: Any,
        **kwargs: Any,
    ) -> AsyncGenerator[list[R2RDocumentProcessingError], None]:
        """
        Executes the async knowledge graph storage pipe: storing knowledge graph extractions in the graph database.

        Items from ``input.message`` that are ``R2RDocumentProcessingError``
        instances are collected instead of stored and are yielded once all
        storage tasks have completed.
        """
        batch_tasks = []
        kg_batch: list[KGExtraction] = []
        errors = []

        async for kg_extraction in input.message:
            if isinstance(kg_extraction, R2RDocumentProcessingError):
                errors.append(kg_extraction)
                continue

            kg_batch.append(kg_extraction)  # type: ignore
            if len(kg_batch) >= self.storage_batch_size:
                # Schedule the storage task; the helper copies the batch, so
                # clearing the working list afterwards is safe.
                batch_tasks.append(self._schedule_store(kg_batch))
                kg_batch.clear()

        if kg_batch:  # Process any remaining extractions
            batch_tasks.append(self._schedule_store(kg_batch))

        # Wait for all storage tasks to complete
        await asyncio.gather(*batch_tasks)

        for error in errors:
            yield error