search_pipeline.py

import asyncio
import logging
from asyncio import Queue
from typing import Any, Optional

from ..base.abstractions import (
    AggregateSearchResult,
    GraphSearchSettings,
    SearchSettings,
)
from ..base.logger.run_manager import RunManager, manage_run
from ..base.pipeline.base_pipeline import AsyncPipeline, dequeue_requests
from ..base.pipes.base_pipe import AsyncPipe, AsyncState

logger = logging.getLogger()


class SearchPipeline(AsyncPipeline):
    """A pipeline for search.

    Fans each incoming query out to a vector search sub-pipeline and a
    knowledge graph (KG) search sub-pipeline, then aggregates their results.
    """

    def __init__(
        self,
        run_manager: Optional[RunManager] = None,
    ):
        super().__init__(run_manager)
        self._parsing_pipe: Optional[AsyncPipe] = None
        self._vector_search_pipeline: Optional[AsyncPipeline] = None
        self._kg_search_pipeline: Optional[AsyncPipeline] = None

    async def run(  # type: ignore
        self,
        input: Any,
        state: Optional[AsyncState],
        stream: bool = False,
        run_manager: Optional[RunManager] = None,
        search_settings: SearchSettings = SearchSettings(),
        *args: Any,
        **kwargs: Any,
    ):
        request_state = state or AsyncState()
        run_manager = run_manager or self.run_manager

        async with manage_run(run_manager):
            vector_search_queue: Queue[str] = Queue()
            kg_queue: Queue[str] = Queue()

            async def enqueue_requests():
                # Fan each incoming message out to both search queues;
                # a trailing None signals end-of-stream to the consumers.
                async for message in input:
                    await vector_search_queue.put(message)
                    await kg_queue.put(message)

                await vector_search_queue.put(None)
                await kg_queue.put(None)

            # Start the document enqueuing process
            enqueue_task = asyncio.create_task(enqueue_requests())

            # Start the embedding and KG pipelines in parallel
            vector_search_task = asyncio.create_task(
                self._vector_search_pipeline.run(
                    dequeue_requests(vector_search_queue),
                    request_state,
                    stream,
                    run_manager,
                    search_settings=search_settings,
                    *args,
                    **kwargs,
                )
            )

            kg_task = asyncio.create_task(
                self._kg_search_pipeline.run(
                    dequeue_requests(kg_queue),
                    request_state,
                    stream,
                    run_manager,
                    search_settings=search_settings,
                    *args,
                    **kwargs,
                )
            )

            # Wait for enqueuing to finish, then collect the results from
            # both search branches and aggregate them into a single result.
            await enqueue_task
            chunk_search_results = await vector_search_task
            kg_results = await kg_task

        return AggregateSearchResult(
            chunk_search_results=chunk_search_results,
            graph_search_results=kg_results,
        )

    def add_pipe(
        self,
        pipe: AsyncPipe,
        add_upstream_outputs: Optional[list[dict[str, str]]] = None,
        kg_search_pipe: bool = False,
        vector_search_pipe: bool = False,
        *args,
        **kwargs,
    ) -> None:
        logger.debug(f"Adding pipe {pipe.config.name} to the SearchPipeline")

        if kg_search_pipe:
            if not self._kg_search_pipeline:
                self._kg_search_pipeline = AsyncPipeline()
            if not self._kg_search_pipeline:
                raise ValueError(
                    "KG search pipeline not found"
                )  # for type hinting

            self._kg_search_pipeline.add_pipe(
                pipe, add_upstream_outputs, *args, **kwargs
            )
        elif vector_search_pipe:
            if not self._vector_search_pipeline:
                self._vector_search_pipeline = AsyncPipeline()
            if not self._vector_search_pipeline:
                raise ValueError(
                    "Vector search pipeline not found"
                )  # for type hinting

            self._vector_search_pipeline.add_pipe(
                pipe, add_upstream_outputs, *args, **kwargs
            )
        else:
            raise ValueError("Pipe must be a vector search or KG pipe")
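

# --- Illustrative usage sketch (not part of the original module) -----------
# A minimal sketch of how SearchPipeline is meant to be wired up and run,
# assuming two hypothetical AsyncPipe subclasses, `MyVectorSearchPipe` and
# `MyKGSearchPipe`, stand in for concrete search pipes. `run()` consumes an
# async iterable of query strings and returns an AggregateSearchResult.
#
#   async def _example():
#       pipeline = SearchPipeline()
#       pipeline.add_pipe(MyVectorSearchPipe(), vector_search_pipe=True)
#       pipeline.add_pipe(MyKGSearchPipe(), kg_search_pipe=True)
#
#       async def queries():
#           yield "What does the search pipeline aggregate?"
#
#       result = await pipeline.run(
#           input=queries(),
#           state=None,
#           search_settings=SearchSettings(),
#       )
#       # result.chunk_search_results / result.graph_search_results hold the
#       # outputs of the vector and KG sub-pipelines, respectively.
#
#   asyncio.run(_example())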