rag_pipeline.py

import asyncio
import logging
from typing import Any, Optional

from ..base.abstractions import (
    GenerationConfig,
    GraphSearchSettings,
    SearchSettings,
)
from ..base.logger.base import RunType
from ..base.logger.run_manager import RunManager, manage_run
from ..base.pipeline.base_pipeline import AsyncPipeline
from ..base.pipes.base_pipe import AsyncPipe, AsyncState
from ..base.utils import to_async_generator

logger = logging.getLogger()


class RAGPipeline(AsyncPipeline):
    """A pipeline for RAG."""

    def __init__(
        self,
        run_manager: Optional[RunManager] = None,
    ):
        super().__init__(run_manager)
        self._search_pipeline: Optional[AsyncPipeline] = None
        self._rag_pipeline: Optional[AsyncPipeline] = None

    async def run(  # type: ignore
        self,
        input: Any,
        state: Optional[AsyncState],
        run_manager: Optional[RunManager] = None,
        search_settings: SearchSettings = SearchSettings(),
        rag_generation_config: GenerationConfig = GenerationConfig(),
        *args: Any,
        **kwargs: Any,
    ):
        """Search each incoming query, then run the RAG pipeline over the results."""
        if not self._rag_pipeline:
            raise ValueError(
                "`_rag_pipeline` must be set before running the RAG pipeline"
            )
        self.state = state or AsyncState()
        # TODO - This feels anti-pattern.
        run_manager = run_manager or self.run_manager or RunManager()
        async with manage_run(run_manager, RunType.RETRIEVAL):
            if not self._search_pipeline:
                raise ValueError(
                    "`_search_pipeline` must be set before running the RAG pipeline"
                )

            # Fan each query out to the search pipeline as a concurrent task,
            # then yield (query, search_results) pairs in submission order.
            async def multi_query_generator(input):
                tasks = []
                async for query in input:
                    input_kwargs = {
                        **kwargs,
                        "search_settings": search_settings,
                    }
                    task = asyncio.create_task(
                        self._search_pipeline.run(
                            to_async_generator([query]),
                            state,
                            False,
                            run_manager,
                            *args,
                            **input_kwargs,
                        )
                    )
                    tasks.append((query, task))

                for query, task in tasks:
                    yield (query, await task)

            # Feed the (query, search results) pairs into the RAG pipeline for generation.
            input_kwargs = {
                **kwargs,
                "rag_generation_config": rag_generation_config,
            }
            rag_results = await self._rag_pipeline.run(
                multi_query_generator(input),
                state,
                rag_generation_config.stream,
                run_manager,
                *args,
                **input_kwargs,
            )
            return rag_results

    def add_pipe(
        self,
        pipe: AsyncPipe,
        add_upstream_outputs: Optional[list[dict[str, str]]] = None,
        rag_pipe: bool = True,
        *args,
        **kwargs,
    ) -> None:
        """Add a pipe to the underlying RAG pipeline."""
        logger.debug(f"Adding pipe {pipe.config.name} to the RAGPipeline")
        if not rag_pipe:
            raise ValueError(
                "Only pipes that are part of the RAG pipeline can be added to the RAG pipeline"
            )
        if not self._rag_pipeline:
            self._rag_pipeline = AsyncPipeline()
        self._rag_pipeline.add_pipe(
            pipe, add_upstream_outputs, *args, **kwargs
        )

    def set_search_pipeline(
        self,
        _search_pipeline: AsyncPipeline,
        *args,
        **kwargs,
    ) -> None:
        """Set the search pipeline used to gather context for generation."""
        logger.debug("Setting search pipeline for the RAGPipeline")
        self._search_pipeline = _search_pipeline
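

# Usage sketch (illustrative, not part of the original module): one possible way to
# wire a RAGPipeline together. `search_pipe` and `rag_pipe` are hypothetical stand-ins
# for concrete AsyncPipe implementations constructed elsewhere in the package; their
# setup, and the exact pipe signatures, are assumed rather than shown here.
async def example_rag_run(search_pipe: AsyncPipe, rag_pipe: AsyncPipe) -> Any:
    # Build a search pipeline around the search pipe and attach it, then register
    # the generation pipe on the RAG pipeline itself.
    search_pipeline = AsyncPipeline()
    search_pipeline.add_pipe(search_pipe)

    rag_pipeline = RAGPipeline()
    rag_pipeline.set_search_pipeline(search_pipeline)
    rag_pipeline.add_pipe(rag_pipe)

    # Queries are passed as an async generator; each is searched first, and the
    # (query, search results) pairs are then streamed into the RAG pipe(s).
    return await rag_pipeline.run(
        input=to_async_generator(["What does the RAG pipeline do?"]),
        state=None,
        search_settings=SearchSettings(),
        rag_generation_config=GenerationConfig(stream=False),
    )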