storage.py

import asyncio
import logging
from typing import Any, AsyncGenerator
from uuid import UUID

from core.base import AsyncState, KGExtraction, R2RDocumentProcessingError
from core.base.pipes.base_pipe import AsyncPipe
from core.database import PostgresDatabaseProvider

logger = logging.getLogger()


class GraphStoragePipe(AsyncPipe):
    # TODO - Apply correct type hints to storage messages
    class Input(AsyncPipe.Input):
        message: AsyncGenerator[list[Any], None]

    def __init__(
        self,
        database_provider: PostgresDatabaseProvider,
        config: AsyncPipe.PipeConfig,
        storage_batch_size: int = 1,
        *args,
        **kwargs,
    ):
  21. """
  22. Initializes the async knowledge graph storage pipe with necessary components and configurations.
  23. """
  24. logger.info(
  25. f"Initializing an `GraphStoragePipe` to store knowledge graph extractions in a graph database."
  26. )
        super().__init__(
            config,
            *args,
            **kwargs,
        )
        self.database_provider = database_provider
        self.storage_batch_size = storage_batch_size

    async def store(
        self,
        kg_extractions: list[KGExtraction],
    ):
        """
        Stores a batch of knowledge graph extractions in the graph database
        and returns the total number of entities and relationships stored.
        """
        total_entities, total_relationships = 0, 0

        for extraction in kg_extractions:
            total_entities, total_relationships = (
                total_entities + len(extraction.entities),
                total_relationships + len(extraction.relationships),
            )

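            # If the first entity lacks chunk_ids, assume none were stamped by
            # the extractor and back-fill provenance (chunk_ids and parent_id)
            # from the extraction itself before persisting.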
            if extraction.entities:
                if not extraction.entities[0].chunk_ids:
                    for i in range(len(extraction.entities)):
                        extraction.entities[i].chunk_ids = extraction.chunk_ids
                        extraction.entities[i].parent_id = (
                            extraction.document_id
                        )

                for entity in extraction.entities:
                    await self.database_provider.graphs_handler.entities.create(
                        **entity.to_dict()
                    )

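            # Relationships get the same provenance back-fill, but unlike
            # entities they are persisted in a single bulk call rather than
            # one at a time.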
            if extraction.relationships:
                if not extraction.relationships[0].chunk_ids:
                    for i in range(len(extraction.relationships)):
                        extraction.relationships[i].chunk_ids = (
                            extraction.chunk_ids
                        )
                        extraction.relationships[i].document_id = (
                            extraction.document_id
                        )

                await self.database_provider.graphs_handler.relationships.create(
                    extraction.relationships,
                )

        return (total_entities, total_relationships)

    async def _run_logic(  # type: ignore
        self,
        input: Input,
        state: AsyncState,
        run_id: UUID,
        *args: Any,
        **kwargs: Any,
    ) -> AsyncGenerator[R2RDocumentProcessingError, None]:
        """
        Executes the async knowledge graph storage pipe: stores knowledge
        graph extractions in the graph database in batches, then yields any
        processing errors encountered in the input stream.
        """
        batch_tasks = []
        kg_batch: list[KGExtraction] = []
        errors = []

        async for kg_extraction in input.message:
            # Errors from upstream pipes pass through; collect them and
            # re-yield them once storage has finished.
            if isinstance(kg_extraction, R2RDocumentProcessingError):
                errors.append(kg_extraction)
                continue

            kg_batch.append(kg_extraction)  # type: ignore
            if len(kg_batch) >= self.storage_batch_size:
                # Schedule the storage task for the full batch
                batch_tasks.append(
                    asyncio.create_task(
                        self.store(kg_batch.copy()),
                        name=f"kg-store-{self.config.name}",
                    )
                )
                kg_batch.clear()

        if kg_batch:  # Process any remaining extractions
            batch_tasks.append(
                asyncio.create_task(
                    self.store(kg_batch.copy()),
                    name=f"kg-store-{self.config.name}",
                )
            )

        # Wait for all storage tasks to complete
        await asyncio.gather(*batch_tasks)

        for error in errors:
            yield error
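

# ---------------------------------------------------------------------------
# Usage sketch (illustrative only, not part of the pipe). It drives
# `_run_logic` with an in-memory stream and a mocked database provider. The
# `AsyncMock` provider, the `PipeConfig(name=...)` and `AsyncState()`
# constructions, and the empty fake extraction are assumptions made for
# demonstration; in R2R the pipe is normally constructed and fed by the
# ingestion pipeline machinery rather than by hand.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    from unittest.mock import AsyncMock, MagicMock

    async def _demo() -> None:
        pipe = GraphStoragePipe(
            database_provider=AsyncMock(),  # stands in for PostgresDatabaseProvider
            config=AsyncPipe.PipeConfig(name="kg_storage_demo"),
            storage_batch_size=2,
        )

        # A minimal fake extraction with no entities or relationships, so the
        # mocked provider is never actually called.
        extraction = MagicMock(spec=KGExtraction)
        extraction.entities = []
        extraction.relationships = []

        async def _stream():
            yield extraction

        async for error in pipe._run_logic(
            pipe.Input(message=_stream()),  # type: ignore
            AsyncState(),
            run_id=UUID(int=0),
        ):
            print("storage error:", error)

    asyncio.run(_demo())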