# embedding_pipe.py
  1. import asyncio
  2. import logging
  3. from typing import Any, AsyncGenerator
  4. from core.base import (
  5. AsyncState,
  6. DocumentChunk,
  7. EmbeddingProvider,
  8. R2RDocumentProcessingError,
  9. Vector,
  10. VectorEntry,
  11. )
  12. from core.base.pipes.base_pipe import AsyncPipe
  13. logger = logging.getLogger()
  14. class EmbeddingPipe(AsyncPipe[VectorEntry]):
  15. """
  16. Embeds extractions using a specified embedding model.
  17. """
  18. class Input(AsyncPipe.Input):
  19. message: list[DocumentChunk]
  20. def __init__(
  21. self,
  22. embedding_provider: EmbeddingProvider,
  23. config: AsyncPipe.PipeConfig,
  24. embedding_batch_size: int = 1,
  25. *args,
  26. **kwargs,
  27. ):
  28. super().__init__(config)
  29. self.embedding_provider = embedding_provider
  30. self.embedding_batch_size = embedding_batch_size
  31. async def embed(self, extractions: list[DocumentChunk]) -> list[float]:
  32. return await self.embedding_provider.async_get_embeddings(
  33. [extraction.data for extraction in extractions], # type: ignore
  34. EmbeddingProvider.PipeStage.BASE,
  35. )
  36. async def _process_batch(
  37. self, extraction_batch: list[DocumentChunk]
  38. ) -> list[VectorEntry]:
  39. vectors = await self.embed(extraction_batch)
  40. return [
  41. VectorEntry(
  42. id=extraction.id,
  43. document_id=extraction.document_id,
  44. owner_id=extraction.owner_id,
  45. collection_ids=extraction.collection_ids,
  46. vector=Vector(data=raw_vector),
  47. text=extraction.data, # type: ignore
  48. metadata={
  49. **extraction.metadata,
  50. },
  51. )
  52. for raw_vector, extraction in zip(vectors, extraction_batch)
  53. ]
  54. async def _run_logic( # type: ignore
  55. self,
  56. input: AsyncPipe.Input,
  57. state: AsyncState,
  58. run_id: Any,
  59. *args: Any,
  60. **kwargs: Any,
  61. ) -> AsyncGenerator[VectorEntry, None]:
  62. if not isinstance(input, EmbeddingPipe.Input):
  63. raise ValueError(
  64. f"Invalid input type for embedding pipe: {type(input)}"
  65. )
  66. extraction_batch = []
  67. batch_size = self.embedding_batch_size
  68. concurrent_limit = (
  69. self.embedding_provider.config.concurrent_request_limit
  70. )
  71. tasks = set()
  72. async def process_batch(batch):
  73. return await self._process_batch(batch)
  74. try:
  75. for item in input.message:
  76. extraction_batch.append(item)
  77. if len(extraction_batch) >= batch_size:
  78. tasks.add(
  79. asyncio.create_task(process_batch(extraction_batch))
  80. )
  81. extraction_batch = []
  82. while len(tasks) >= concurrent_limit:
  83. done, tasks = await asyncio.wait(
  84. tasks, return_when=asyncio.FIRST_COMPLETED
  85. )
  86. for task in done:
  87. for vector_entry in await task:
  88. yield vector_entry
  89. if extraction_batch:
  90. tasks.add(asyncio.create_task(process_batch(extraction_batch)))
  91. for future_task in asyncio.as_completed(tasks):
  92. for vector_entry in await future_task:
  93. yield vector_entry
  94. finally:
  95. # Ensure all tasks are completed
  96. if tasks:
  97. await asyncio.gather(*tasks, return_exceptions=True)
  98. async def _process_extraction(
  99. self, extraction: DocumentChunk
  100. ) -> VectorEntry | R2RDocumentProcessingError:
  101. try:
  102. if isinstance(extraction.data, bytes):
  103. raise ValueError(
  104. "extraction data is in bytes format, which is not supported by the embedding provider."
  105. )
  106. vectors = await self.embedding_provider.async_get_embeddings(
  107. [extraction.data],
  108. EmbeddingProvider.PipeStage.BASE,
  109. )
  110. return VectorEntry(
  111. id=extraction.id,
  112. document_id=extraction.document_id,
  113. owner_id=extraction.owner_id,
  114. collection_ids=extraction.collection_ids,
  115. vector=Vector(data=vectors[0]),
  116. text=extraction.data,
  117. metadata={**extraction.metadata},
  118. )
  119. except Exception as e:
  120. logger.error(f"Error processing extraction: {e}")
  121. return R2RDocumentProcessingError(
  122. error_message=str(e),
  123. document_id=extraction.document_id,
  124. )