routing_search_pipe.py 1.2 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243
  1. from typing import Any, AsyncGenerator
  2. from uuid import UUID
  3. from core.base import AsyncPipe, AsyncState, ChunkSearchResult, SearchSettings
  4. class RoutingSearchPipe(AsyncPipe):
  5. def __init__(
  6. self,
  7. search_pipes: dict[str, AsyncPipe],
  8. default_strategy: str,
  9. config: AsyncPipe.PipeConfig,
  10. *args,
  11. **kwargs,
  12. ):
  13. super().__init__(config, *args, **kwargs)
  14. self.search_pipes = search_pipes
  15. self.default_strategy = default_strategy
  16. async def _run_logic( # type: ignore
  17. self,
  18. input: AsyncPipe.Input,
  19. state: AsyncState,
  20. run_id: UUID,
  21. search_settings: SearchSettings,
  22. *args: Any,
  23. **kwargs: Any,
  24. ) -> AsyncGenerator[ChunkSearchResult, None]:
  25. search_pipe = self.search_pipes.get(search_settings.search_strategy)
  26. if not search_pipe:
  27. raise ValueError(
  28. f"Search strategy {search_settings.search_strategy} not found"
  29. )
  30. async for result in search_pipe._run_logic( # type: ignore
  31. input,
  32. state,
  33. run_id,
  34. search_settings=search_settings,
  35. *args,
  36. **kwargs,
  37. ):
  38. yield result