- """Base pipeline class for running a sequence of pipes."""
- import asyncio
- import logging
- import traceback
- from typing import Any, AsyncGenerator, Optional
- from ..logger.run_manager import RunManager, manage_run
- from ..pipes.base_pipe import AsyncPipe, AsyncState
- logger = logging.getLogger()

class AsyncPipeline:
    """Pipeline class for running a sequence of pipes."""

    def __init__(
        self,
        run_manager: Optional[RunManager] = None,
    ):
        self.pipes: list[AsyncPipe] = []
        self.upstream_outputs: list[list[dict[str, str]]] = []
        self.run_manager = run_manager or RunManager()
        self.futures: dict[str, asyncio.Future] = {}
        self.level = 0

    def add_pipe(
        self,
        pipe: AsyncPipe,
        add_upstream_outputs: Optional[list[dict[str, str]]] = None,
        *args,
        **kwargs,
    ) -> None:
        """Add a pipe to the pipeline.

        Each entry in `add_upstream_outputs` maps an earlier pipe's output onto
        an input field of this pipe and is expected to contain the keys
        `prev_pipe_name`, `prev_output_field`, and `input_field`.
        """
        self.pipes.append(pipe)
        if not add_upstream_outputs:
            add_upstream_outputs = []
        self.upstream_outputs.append(add_upstream_outputs)

    async def run(
        self,
        input: Any,
        state: Optional[AsyncState] = None,
        stream: bool = False,
        run_manager: Optional[RunManager] = None,
        *args: Any,
        **kwargs: Any,
    ):
        """Run the pipeline."""
        run_manager = run_manager or self.run_manager
        self.state = state or AsyncState()
        current_input = input
        async with manage_run(run_manager):
            try:
                for pipe_num in range(len(self.pipes)):
                    config_name = self.pipes[pipe_num].config.name
                    self.futures[config_name] = asyncio.Future()
                    current_input = self._run_pipe(
                        pipe_num,
                        current_input,
                        run_manager,
                        *args,
                        **kwargs,
                    )
                    self.futures[config_name].set_result(current_input)
            except Exception as error:
                # TODO: improve error handling here
                error_trace = traceback.format_exc()
                logger.error(
                    f"Pipeline failed with error: {error}\n\nStack trace:\n{error_trace}"
                )
                raise error

            return (
                current_input
                if stream
                else await self._consume_all(current_input)
            )

    async def _consume_all(self, gen: AsyncGenerator) -> list[Any]:
        result = []
        async for item in gen:
            if hasattr(item, "__aiter__"):  # the item is an async generator
                sub_result = await self._consume_all(item)
                result.extend(sub_result)
            else:
                result.append(item)
        return result

    async def _run_pipe(
        self,
        pipe_num: int,
        input: Any,
        run_manager: RunManager,
        *args: Any,
        **kwargs: Any,
    ):
        # Collect inputs, waiting for the necessary futures
        pipe = self.pipes[pipe_num]
        add_upstream_outputs = self.sort_upstream_outputs(
            self.upstream_outputs[pipe_num]
        )
        input_dict = {"message": input}

        # Group upstream outputs by prev_pipe_name
        grouped_upstream_outputs: dict[str, list] = {}
        for upstream_input in add_upstream_outputs:
            upstream_pipe_name = upstream_input["prev_pipe_name"]
            if upstream_pipe_name not in grouped_upstream_outputs:
                grouped_upstream_outputs[upstream_pipe_name] = []
            grouped_upstream_outputs[upstream_pipe_name].append(upstream_input)

        for (
            upstream_pipe_name,
            upstream_inputs,
        ) in grouped_upstream_outputs.items():

            async def resolve_future_output(future):
                result = future.result()
                # consume the async generator
                return [item async for item in result]

            async def replay_items_as_async_gen(items):
                for item in items:
                    yield item

            temp_results = await resolve_future_output(
                self.futures[upstream_pipe_name]
            )
            if upstream_pipe_name == self.pipes[pipe_num - 1].config.name:
                # The preceding pipe's generator was just consumed above, so
                # replay its items as a fresh async generator for `message`.
                input_dict["message"] = replay_items_as_async_gen(temp_results)

            for upstream_input in upstream_inputs:
                outputs = await self.state.get(upstream_pipe_name, "output")
                prev_output_field = upstream_input.get(
                    "prev_output_field", None
                )
                if not prev_output_field:
                    raise ValueError(
                        "`prev_output_field` must be specified in the upstream_input"
                    )
                input_dict[upstream_input["input_field"]] = outputs[
                    prev_output_field
                ]

        async for ele in await pipe.run(
            pipe.Input(**input_dict),
            self.state,
            run_manager,
            *args,
            **kwargs,
        ):
            yield ele

    def sort_upstream_outputs(
        self, add_upstream_outputs: list[dict[str, str]]
    ) -> list[dict[str, str]]:
        """Sort upstream outputs so those from later pipes come first."""
        pipe_name_to_index = {
            pipe.config.name: index for index, pipe in enumerate(self.pipes)
        }

        def get_pipe_index(upstream_output):
            return pipe_name_to_index[upstream_output["prev_pipe_name"]]

        sorted_outputs = sorted(
            add_upstream_outputs, key=get_pipe_index, reverse=True
        )
        return sorted_outputs
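
# A minimal usage sketch (not part of the original module). It assumes two
# AsyncPipe instances, `search_pipe` and `rag_pipe`, constructed elsewhere,
# and the field names "search_results" and "raw_search_results" are purely
# hypothetical; it only illustrates how upstream outputs are wired into a
# later pipe via `add_pipe` and then consumed by `run`.
async def _example_pipeline_usage(
    search_pipe: AsyncPipe, rag_pipe: AsyncPipe
) -> list[Any]:
    pipeline = AsyncPipeline()
    pipeline.add_pipe(search_pipe)
    pipeline.add_pipe(
        rag_pipe,
        add_upstream_outputs=[
            {
                "prev_pipe_name": search_pipe.config.name,
                "prev_output_field": "search_results",  # hypothetical field
                "input_field": "raw_search_results",  # hypothetical field
            }
        ],
    )
    # With stream=False, the pipeline's async generator output is consumed
    # into a list before being returned.
    return await pipeline.run(input="example query", stream=False)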

async def dequeue_requests(queue: asyncio.Queue) -> AsyncGenerator:
    """Create an async generator to dequeue requests."""
    while True:
        request = await queue.get()
        if request is None:
            break
        yield request
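
# Illustrative demo (not part of the original module): drains a queue through
# `dequeue_requests` until the `None` sentinel is reached. The queued strings
# are placeholder values.
async def _example_dequeue() -> None:
    queue: asyncio.Queue = asyncio.Queue()
    for request in ("first", "second", "third"):
        await queue.put(request)
    await queue.put(None)  # sentinel that ends the generator

    async for request in dequeue_requests(queue):
        print(request)  # prints "first", "second", "third"


if __name__ == "__main__":
    asyncio.run(_example_dequeue())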