123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115 |
- import asyncio
- import logging
- from typing import Any, Optional
- from ..base.abstractions import (
- GenerationConfig,
- GraphSearchSettings,
- SearchSettings,
- )
- from ..base.logger.run_manager import RunManager, manage_run
- from ..base.pipeline.base_pipeline import AsyncPipeline
- from ..base.pipes.base_pipe import AsyncPipe, AsyncState
- from ..base.utils import to_async_generator
- logger = logging.getLogger()
- class RAGPipeline(AsyncPipeline):
- """A pipeline for RAG."""
- def __init__(
- self,
- run_manager: Optional[RunManager] = None,
- ):
- super().__init__(run_manager)
- self._search_pipeline: Optional[AsyncPipeline] = None
- self._rag_pipeline: Optional[AsyncPipeline] = None
- async def run( # type: ignore
- self,
- input: Any,
- state: Optional[AsyncState],
- run_manager: Optional[RunManager] = None,
- search_settings: SearchSettings = SearchSettings(),
- rag_generation_config: GenerationConfig = GenerationConfig(),
- *args: Any,
- **kwargs: Any,
- ):
- if not self._rag_pipeline:
- raise ValueError(
- "`_rag_pipeline` must be set before running the RAG pipeline"
- )
- self.state = state or AsyncState()
- # TODO - This feels anti-pattern.
- run_manager = run_manager or self.run_manager or RunManager()
- async with manage_run(run_manager):
- if not self._search_pipeline:
- raise ValueError(
- "`_search_pipeline` must be set before running the RAG pipeline"
- )
- async def multi_query_generator(input):
- tasks = []
- async for query in input:
- input_kwargs = {
- **kwargs,
- "search_settings": search_settings,
- }
- task = asyncio.create_task(
- self._search_pipeline.run(
- to_async_generator([query]),
- state,
- False,
- run_manager,
- *args,
- **input_kwargs,
- )
- )
- tasks.append((query, task))
- for query, task in tasks:
- yield (query, await task)
- input_kwargs = {
- **kwargs,
- "rag_generation_config": rag_generation_config,
- }
- rag_results = await self._rag_pipeline.run(
- multi_query_generator(input),
- state,
- rag_generation_config.stream,
- run_manager,
- *args,
- **input_kwargs,
- )
- return rag_results
- def add_pipe(
- self,
- pipe: AsyncPipe,
- add_upstream_outputs: Optional[list[dict[str, str]]] = None,
- rag_pipe: bool = True,
- *args,
- **kwargs,
- ) -> None:
- logger.debug(f"Adding pipe {pipe.config.name} to the RAGPipeline")
- if not rag_pipe:
- raise ValueError(
- "Only pipes that are part of the RAG pipeline can be added to the RAG pipeline"
- )
- if not self._rag_pipeline:
- self._rag_pipeline = AsyncPipeline()
- self._rag_pipeline.add_pipe(
- pipe, add_upstream_outputs, *args, **kwargs
- )
- def set_search_pipeline(
- self,
- _search_pipeline: AsyncPipeline,
- *args,
- **kwargs,
- ) -> None:
- logger.debug("Setting search pipeline for the RAGPipeline")
- self._search_pipeline = _search_pipeline
|