123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126 |
- import asyncio
- import logging
- from asyncio import Queue
- from typing import Any, Optional
- from ..base.abstractions import (
- AggregateSearchResult,
- GraphSearchSettings,
- SearchSettings,
- )
- from ..base.logger.run_manager import RunManager, manage_run
- from ..base.pipeline.base_pipeline import AsyncPipeline, dequeue_requests
- from ..base.pipes.base_pipe import AsyncPipe, AsyncState
# Module-level logger.
# NOTE(review): this attaches to the *root* logger — presumably intentional so
# output follows the application's root configuration, but consider
# logging.getLogger(__name__) so records are attributable to this module; confirm
# against the project's logging conventions before changing.
logger = logging.getLogger()
class SearchPipeline(AsyncPipeline):
    """A pipeline for search.

    Fans each incoming query out to two sub-pipelines — vector (chunk)
    search and knowledge-graph search — runs them concurrently, and
    aggregates both result sets into a single ``AggregateSearchResult``.
    Sub-pipelines are assembled via :meth:`add_pipe` before :meth:`run`
    is called.
    """

    def __init__(
        self,
        run_manager: Optional[RunManager] = None,
    ):
        super().__init__(run_manager)
        # Sub-pipelines are lazily created the first time a matching pipe
        # is added via add_pipe(); until then they remain None.
        self._parsing_pipe: Optional[AsyncPipe] = None
        self._vector_search_pipeline: Optional[AsyncPipeline] = None
        self._kg_search_pipeline: Optional[AsyncPipeline] = None

    async def run(  # type: ignore
        self,
        input: Any,
        state: Optional[AsyncState],
        stream: bool = False,
        run_manager: Optional[RunManager] = None,
        search_settings: Optional[SearchSettings] = None,
        *args: Any,
        **kwargs: Any,
    ):
        """Run the search over ``input`` (an async iterable of queries).

        Each message from ``input`` is duplicated onto both the vector
        and KG sub-pipelines, which run concurrently; their outputs are
        combined into an ``AggregateSearchResult``.

        Raises:
            ValueError: if either sub-pipeline has not been configured
                via :meth:`add_pipe`.
        """
        # BUG FIX: the original default was ``search_settings=SearchSettings()``,
        # a mutable default evaluated once at import time and shared across
        # every call to run(). Construct a fresh instance per call instead.
        if search_settings is None:
            search_settings = SearchSettings()

        # BUG FIX: fail fast with a clear error instead of an opaque
        # AttributeError from inside a task when a sub-pipeline is missing.
        # (Both pipelines are always invoked below, so both are required.)
        if self._vector_search_pipeline is None:
            raise ValueError("Vector search pipeline not found")
        if self._kg_search_pipeline is None:
            raise ValueError("KG search pipeline not found")

        request_state = state or AsyncState()
        run_manager = run_manager or self.run_manager

        async with manage_run(run_manager):
            vector_search_queue: Queue[str] = Queue()
            kg_queue: Queue[str] = Queue()

            async def enqueue_requests():
                # Duplicate every incoming message onto both queues, then
                # signal end-of-stream with a None sentinel on each.
                async for message in input:
                    await vector_search_queue.put(message)
                    await kg_queue.put(message)
                await vector_search_queue.put(None)
                await kg_queue.put(None)

            # Start the document enqueuing process.
            enqueue_task = asyncio.create_task(enqueue_requests())

            # Start the embedding and KG pipelines in parallel.
            # NOTE: ``*args`` is placed before the keyword argument — the
            # original interleaved ``search_settings=..., *args`` which still
            # binds *args positionally after run_manager and risks a
            # positional/keyword collision; this ordering is equivalent but
            # unambiguous.
            vector_search_task = asyncio.create_task(
                self._vector_search_pipeline.run(
                    dequeue_requests(vector_search_queue),
                    request_state,
                    stream,
                    run_manager,
                    *args,
                    search_settings=search_settings,
                    **kwargs,
                )
            )
            kg_task = asyncio.create_task(
                self._kg_search_pipeline.run(
                    dequeue_requests(kg_queue),
                    request_state,
                    stream,
                    run_manager,
                    *args,
                    search_settings=search_settings,
                    **kwargs,
                )
            )

            await enqueue_task
            chunk_search_results = await vector_search_task
            kg_results = await kg_task

            return AggregateSearchResult(
                chunk_search_results=chunk_search_results,
                graph_search_results=kg_results,
            )

    def add_pipe(
        self,
        pipe: AsyncPipe,
        add_upstream_outputs: Optional[list[dict[str, str]]] = None,
        kg_search_pipe: bool = False,
        vector_search_pipe: bool = False,
        *args,
        **kwargs,
    ) -> None:
        """Route ``pipe`` into the KG or vector sub-pipeline.

        Exactly one of ``kg_search_pipe`` / ``vector_search_pipe`` must be
        True; the target sub-pipeline is created on first use.

        Raises:
            ValueError: if neither flag is set.
        """
        logger.debug(f"Adding pipe {pipe.config.name} to the SearchPipeline")

        if kg_search_pipe:
            # Lazily create the sub-pipeline, then bind a local so the type
            # checker sees a non-Optional value (replaces the original's
            # unreachable "for type hinting" raise).
            if self._kg_search_pipeline is None:
                self._kg_search_pipeline = AsyncPipeline()
            pipeline = self._kg_search_pipeline
        elif vector_search_pipe:
            if self._vector_search_pipeline is None:
                self._vector_search_pipeline = AsyncPipeline()
            pipeline = self._vector_search_pipeline
        else:
            raise ValueError("Pipe must be a vector search or KG pipe")

        pipeline.add_pipe(pipe, add_upstream_outputs, *args, **kwargs)
|