base_pipeline.py 5.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180
  1. """Base pipeline class for running a sequence of pipes."""
  2. import asyncio
  3. import logging
  4. import traceback
  5. from typing import Any, AsyncGenerator, Optional
  6. from ..logger.run_manager import RunManager, manage_run
  7. from ..pipes.base_pipe import AsyncPipe, AsyncState
  8. logger = logging.getLogger()
class AsyncPipeline:
    """Pipeline class for running a sequence of pipes.

    Pipes execute in insertion order.  Each pipe's output stream becomes
    the next pipe's ``message`` input, and a pipe may additionally pull
    named fields produced by any earlier pipe via the upstream-output
    wiring registered in :meth:`add_pipe`.
    """

    def __init__(
        self,
        run_manager: Optional[RunManager] = None,
    ):
        # Ordered list of pipes; execution follows insertion order.
        self.pipes: list[AsyncPipe] = []
        # Per-pipe upstream wiring specs; index-aligned with `self.pipes`.
        self.upstream_outputs: list[list[dict[str, str]]] = []
        self.run_manager = run_manager or RunManager()
        # Maps a pipe's config name to a future holding that pipe's (lazy)
        # output generator; resolved by downstream pipes in `_run_pipe`.
        self.futures: dict[str, asyncio.Future] = {}
        # NOTE(review): `level` is never read elsewhere in this file —
        # presumably consumed by subclasses or callers; confirm before use.
        self.level = 0

    def add_pipe(
        self,
        pipe: AsyncPipe,
        add_upstream_outputs: Optional[list[dict[str, str]]] = None,
        *args,
        **kwargs,
    ) -> None:
        """Add a pipe to the pipeline.

        `add_upstream_outputs` entries are dicts expected to carry
        `prev_pipe_name`, `prev_output_field`, and `input_field` keys
        (see `_run_pipe`); falsy values are normalized to an empty list
        so `upstream_outputs` stays index-aligned with `pipes`.
        """
        self.pipes.append(pipe)
        if not add_upstream_outputs:
            add_upstream_outputs = []
        self.upstream_outputs.append(add_upstream_outputs)

    async def run(
        self,
        input: Any,
        state: Optional[AsyncState] = None,
        stream: bool = False,
        run_manager: Optional[RunManager] = None,
        *args: Any,
        **kwargs: Any,
    ):
        """Run the pipeline.

        When `stream` is True the final pipe's async generator is returned
        unconsumed; otherwise it is fully drained into a flat list via
        `_consume_all`.
        """
        run_manager = run_manager or self.run_manager
        self.state = state or AsyncState()
        current_input = input
        async with manage_run(run_manager):
            try:
                for pipe_num in range(len(self.pipes)):
                    config_name = self.pipes[pipe_num].config.name
                    self.futures[config_name] = asyncio.Future()
                    # `_run_pipe` is an async generator function, so this
                    # call only *constructs* the generator — the pipe body
                    # does not execute until the generator is consumed.
                    current_input = self._run_pipe(
                        pipe_num,
                        current_input,
                        run_manager,
                        *args,
                        **kwargs,
                    )
                    # Publish the (unstarted) generator so later pipes can
                    # await and replay this pipe's output.
                    self.futures[config_name].set_result(current_input)
            except Exception as error:
                # TODO: improve error handling here
                # NOTE(review): because pipe bodies run lazily (see above),
                # most pipe failures surface during consumption below, not
                # inside this try block.
                error_trace = traceback.format_exc()
                logger.error(
                    f"Pipeline failed with error: {error}\n\nStack trace:\n{error_trace}"
                )
                raise error
            return (
                current_input
                if stream
                else await self._consume_all(current_input)
            )

    async def _consume_all(self, gen: AsyncGenerator) -> list[Any]:
        """Drain `gen`, recursively flattening any nested async generators."""
        result = []
        async for item in gen:
            if hasattr(
                item, "__aiter__"
            ):  # Check if the item is an async generator
                sub_result = await self._consume_all(item)
                result.extend(sub_result)
            else:
                result.append(item)
        return result

    async def _run_pipe(
        self,
        pipe_num: int,
        input: Any,
        run_manager: RunManager,
        *args: Any,
        **kwargs: Any,
    ):
        """Execute pipe `pipe_num`, yielding its output items.

        Builds the pipe's input dict from `input` plus any upstream
        outputs wired via `add_pipe`, then delegates to `pipe.run`.
        """
        # Collect inputs, waiting for the necessary futures
        pipe = self.pipes[pipe_num]
        add_upstream_outputs = self.sort_upstream_outputs(
            self.upstream_outputs[pipe_num]
        )
        input_dict = {"message": input}

        # Group upstream outputs by prev_pipe_name
        grouped_upstream_outputs: dict[str, list] = {}
        for upstream_input in add_upstream_outputs:
            upstream_pipe_name = upstream_input["prev_pipe_name"]
            if upstream_pipe_name not in grouped_upstream_outputs:
                grouped_upstream_outputs[upstream_pipe_name] = []
            grouped_upstream_outputs[upstream_pipe_name].append(upstream_input)

        for (
            upstream_pipe_name,
            upstream_inputs,
        ) in grouped_upstream_outputs.items():
            # These helpers are redefined on each loop iteration; they close
            # over nothing loop-varying, so behavior is unaffected.

            async def resolve_future_output(future):
                result = future.result()
                # consume the async generator
                return [item async for item in result]

            async def replay_items_as_async_gen(items):
                for item in items:
                    yield item

            temp_results = await resolve_future_output(
                self.futures[upstream_pipe_name]
            )
            # If the upstream pipe is the immediately preceding one, its
            # generator was just consumed above — replace `message` with a
            # replay generator so this pipe can still stream the items.
            # NOTE(review): for pipe_num == 0, `pipe_num - 1` wraps to the
            # last pipe; confirm upstream wiring on the first pipe is
            # intentional/unsupported.
            if upstream_pipe_name == self.pipes[pipe_num - 1].config.name:
                input_dict["message"] = replay_items_as_async_gen(temp_results)

            for upstream_input in upstream_inputs:
                outputs = await self.state.get(upstream_pipe_name, "output")
                prev_output_field = upstream_input.get(
                    "prev_output_field", None
                )
                if not prev_output_field:
                    raise ValueError(
                        "`prev_output_field` must be specified in the upstream_input"
                    )
                input_dict[upstream_input["input_field"]] = outputs[
                    prev_output_field
                ]

        # `pipe.run(...)` is awaited and the result iterated — it is
        # expected to return an async generator of output items.
        async for ele in await pipe.run(
            pipe.Input(**input_dict),
            self.state,
            run_manager,
            *args,
            **kwargs,
        ):
            yield ele

    def sort_upstream_outputs(
        self, add_upstream_outputs: list[dict[str, str]]
    ) -> list[dict[str, str]]:
        """Sort upstream specs by producing pipe's position, latest first."""
        pipe_name_to_index = {
            pipe.config.name: index for index, pipe in enumerate(self.pipes)
        }

        def get_pipe_index(upstream_output):
            return pipe_name_to_index[upstream_output["prev_pipe_name"]]

        # reverse=True: outputs from later pipes are processed first.
        sorted_outputs = sorted(
            add_upstream_outputs, key=get_pipe_index, reverse=True
        )
        return sorted_outputs
  150. async def dequeue_requests(queue: asyncio.Queue) -> AsyncGenerator:
  151. """Create an async generator to dequeue requests."""
  152. while True:
  153. request = await queue.get()
  154. if request is None:
  155. break
  156. yield request