rag_pipeline.py

import asyncio
import logging
from typing import Any, Optional

from ..base.abstractions import (
    GenerationConfig,
    GraphSearchSettings,
    SearchSettings,
)
from ..base.logger.run_manager import RunManager, manage_run
from ..base.pipeline.base_pipeline import AsyncPipeline
from ..base.pipes.base_pipe import AsyncPipe, AsyncState
from ..base.utils import to_async_generator

logger = logging.getLogger()

class RAGPipeline(AsyncPipeline):
    """A pipeline for Retrieval-Augmented Generation (RAG).

    Wraps two inner pipelines: a search pipeline that retrieves context for
    each incoming query, and a RAG pipeline that generates completions from
    the retrieved results.
    """

    def __init__(
        self,
        run_manager: Optional[RunManager] = None,
    ):
        super().__init__(run_manager)
        # Both inner pipelines must be configured (via `set_search_pipeline`
        # and `add_pipe`) before `run` is called.
        self._search_pipeline: Optional[AsyncPipeline] = None
        self._rag_pipeline: Optional[AsyncPipeline] = None

    async def run(  # type: ignore
        self,
        input: Any,
        state: Optional[AsyncState],
        run_manager: Optional[RunManager] = None,
        search_settings: SearchSettings = SearchSettings(),
        rag_generation_config: GenerationConfig = GenerationConfig(),
        *args: Any,
        **kwargs: Any,
    ):
        """Run a search for each incoming query, then feed the
        (query, search results) pairs into the RAG pipeline."""
        if not self._rag_pipeline:
            raise ValueError(
                "`_rag_pipeline` must be set before running the RAG pipeline"
            )
        self.state = state or AsyncState()
        # TODO - This feels anti-pattern.
        run_manager = run_manager or self.run_manager or RunManager()
        async with manage_run(run_manager):
            if not self._search_pipeline:
                raise ValueError(
                    "`_search_pipeline` must be set before running the RAG pipeline"
                )

            async def multi_query_generator(input):
                # Fan out: schedule one search pipeline run per query so the
                # searches execute concurrently, then yield the results in
                # the original query order.
                tasks = []
                async for query in input:
                    input_kwargs = {
                        **kwargs,
                        "search_settings": search_settings,
                    }
                    task = asyncio.create_task(
                        self._search_pipeline.run(
                            to_async_generator([query]),
                            state,
                            False,
                            run_manager,
                            *args,
                            **input_kwargs,
                        )
                    )
                    tasks.append((query, task))

                for query, task in tasks:
                    yield (query, await task)

            input_kwargs = {
                **kwargs,
                "rag_generation_config": rag_generation_config,
            }
            rag_results = await self._rag_pipeline.run(
                multi_query_generator(input),
                state,
                rag_generation_config.stream,
                run_manager,
                *args,
                **input_kwargs,
            )
            return rag_results

    def add_pipe(
        self,
        pipe: AsyncPipe,
        add_upstream_outputs: Optional[list[dict[str, str]]] = None,
        rag_pipe: bool = True,
        *args,
        **kwargs,
    ) -> None:
        """Add a pipe to the inner RAG pipeline, creating that pipeline on
        first use. Only RAG pipes are accepted."""
        logger.debug(f"Adding pipe {pipe.config.name} to the RAGPipeline")
        if not rag_pipe:
            raise ValueError(
                "Only pipes that are part of the RAG pipeline can be added to the RAG pipeline"
            )
        if not self._rag_pipeline:
            self._rag_pipeline = AsyncPipeline()
        self._rag_pipeline.add_pipe(
            pipe, add_upstream_outputs, *args, **kwargs
        )

    def set_search_pipeline(
        self,
        _search_pipeline: AsyncPipeline,
        *args,
        **kwargs,
    ) -> None:
        """Attach the search pipeline that `run` uses to retrieve context."""
        logger.debug("Setting search pipeline for the RAGPipeline")
        self._search_pipeline = _search_pipeline
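

# ---------------------------------------------------------------------------
# Usage sketch (illustrative only, not part of the original module). It
# assumes application-defined AsyncPipe implementations -- `MySearchPipe` and
# `MyRAGGenerationPipe` are placeholder names, not classes in this package --
# that perform retrieval and completion generation:
#
#     search_pipeline = AsyncPipeline()
#     search_pipeline.add_pipe(MySearchPipe())
#
#     rag_pipeline = RAGPipeline()
#     rag_pipeline.set_search_pipeline(search_pipeline)
#     rag_pipeline.add_pipe(MyRAGGenerationPipe())
#
#     results = asyncio.run(
#         rag_pipeline.run(
#             input=to_async_generator(["What does this pipeline do?"]),
#             state=None,
#             search_settings=SearchSettings(),
#             rag_generation_config=GenerationConfig(),
#         )
#     )
# ---------------------------------------------------------------------------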