streaming_rag_pipe.py

import logging
from typing import Any, AsyncGenerator, Generator
from uuid import UUID

from core.base import (
    AsyncState,
    CompletionProvider,
    DatabaseProvider,
    LLMChatCompletionChunk,
    format_search_results_for_llm,
    format_search_results_for_stream,
)
from core.base.abstractions import GenerationConfig

from ..abstractions.generator_pipe import GeneratorPipe

logger = logging.getLogger()

class StreamingRAGPipe(GeneratorPipe):
    CHUNK_SEARCH_STREAM_MARKER = (
        "search"  # TODO - change this to vector_search in next major release
    )
    COMPLETION_STREAM_MARKER = "completion"

    def __init__(
        self,
        llm_provider: CompletionProvider,
        database_provider: DatabaseProvider,
        config: GeneratorPipe.PipeConfig,
        *args,
        **kwargs,
    ):
        super().__init__(
            llm_provider,
            database_provider,
            config,
            *args,
            **kwargs,
        )
        self._config: GeneratorPipe.PipeConfig

    @property
    def config(self) -> GeneratorPipe.PipeConfig:
        return self._config

    async def _run_logic(  # type: ignore
        self,
        input: GeneratorPipe.Input,
        state: AsyncState,
        run_id: UUID,
        rag_generation_config: GenerationConfig,
        *args: Any,
        **kwargs: Any,
    ) -> AsyncGenerator[str, None]:
        context = ""
        async for query, search_results in input.message:
            result = format_search_results_for_stream(search_results)
            yield result
            gen_context = format_search_results_for_llm(search_results)
            context += gen_context

        messages = (
            await self.database_provider.prompts_handler.get_message_payload(
                system_prompt_name=self.config.system_prompt,
                task_prompt_name=self.config.task_prompt,
                task_inputs={"query": query, "context": context},
            )
        )

        yield f"<{self.COMPLETION_STREAM_MARKER}>"
        response = ""
        for chunk in self.llm_provider.get_completion_stream(
            messages=messages, generation_config=rag_generation_config
        ):
            chunk_txt = StreamingRAGPipe._process_chunk(chunk)
            response += chunk_txt
            yield chunk_txt
        yield f"</{self.COMPLETION_STREAM_MARKER}>"

    async def _yield_chunks(
        self,
        start_marker: str,
        chunks: Generator[str, None, None],
        end_marker: str,
    ) -> AsyncGenerator[str, None]:
        yield start_marker
        for chunk in chunks:
            yield chunk
        yield end_marker

    @staticmethod
    def _process_chunk(chunk: LLMChatCompletionChunk) -> str:
        return chunk.choices[0].delta.content or ""