parsing_pipe.py 2.9 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677787980818283848586878889
  1. import logging
  2. from typing import AsyncGenerator, Optional
  3. from uuid import UUID
  4. from core.base import AsyncState, DatabaseProvider, Document, DocumentChunk
  5. from core.base.abstractions import R2RDocumentProcessingError
  6. from core.base.pipes.base_pipe import AsyncPipe
  7. from core.base.providers.ingestion import IngestionProvider
  8. from core.utils import generate_extraction_id
  9. logger = logging.getLogger()
  10. class ParsingPipe(AsyncPipe):
  11. class Input(AsyncPipe.Input):
  12. message: Document
  13. def __init__(
  14. self,
  15. database_provider: DatabaseProvider,
  16. ingestion_provider: IngestionProvider,
  17. config: AsyncPipe.PipeConfig,
  18. *args,
  19. **kwargs,
  20. ):
  21. super().__init__(
  22. config,
  23. *args,
  24. **kwargs,
  25. )
  26. self.database_provider = database_provider
  27. self.ingestion_provider = ingestion_provider
  28. async def _parse(
  29. self,
  30. document: Document,
  31. run_id: UUID,
  32. version: str,
  33. ingestion_config_override: Optional[dict],
  34. ) -> AsyncGenerator[DocumentChunk, None]:
  35. try:
  36. ingestion_config_override = ingestion_config_override or {}
  37. override_provider = ingestion_config_override.pop("provider", None)
  38. if (
  39. override_provider
  40. and override_provider
  41. != self.ingestion_provider.config.provider
  42. ):
  43. raise ValueError(
  44. f"Provider '{override_provider}' does not match ingestion provider '{self.ingestion_provider.config.provider}'."
  45. )
  46. if result := await self.database_provider.files_handler.retrieve_file(
  47. document.id
  48. ):
  49. file_name, file_wrapper, file_size = result
  50. with file_wrapper as file_content_stream:
  51. file_content = file_content_stream.read()
  52. async for extraction in self.ingestion_provider.parse( # type: ignore
  53. file_content, document, ingestion_config_override
  54. ):
  55. id = generate_extraction_id(extraction.id, version=version)
  56. extraction.id = id
  57. extraction.metadata["version"] = version
  58. yield extraction
  59. except Exception as e:
  60. raise R2RDocumentProcessingError(
  61. document_id=document.id,
  62. error_message=f"Error parsing document: {str(e)}",
  63. )
  64. async def _run_logic( # type: ignore
  65. self,
  66. input: AsyncPipe.Input,
  67. state: AsyncState,
  68. run_id: UUID,
  69. *args,
  70. **kwargs,
  71. ) -> AsyncGenerator[DocumentChunk, None]:
  72. ingestion_config = kwargs.get("ingestion_config")
  73. async for result in self._parse(
  74. input.message,
  75. run_id,
  76. input.message.metadata.get("version", "v0"),
  77. ingestion_config_override=ingestion_config,
  78. ):
  79. yield result