query_transform_pipe.py 2.9 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091
  1. import logging
  2. from typing import Any, AsyncGenerator
  3. from uuid import UUID
  4. from core.base import (
  5. AsyncPipe,
  6. AsyncState,
  7. CompletionProvider,
  8. DatabaseProvider,
  9. )
  10. from core.base.abstractions import GenerationConfig
  11. from ..abstractions.generator_pipe import GeneratorPipe
  12. logger = logging.getLogger()
  13. class QueryTransformPipe(GeneratorPipe):
  14. class QueryTransformConfig(GeneratorPipe.PipeConfig):
  15. name: str = "default_query_transform"
  16. system_prompt: str = "default_system"
  17. task_prompt: str = "hyde"
  18. class Input(GeneratorPipe.Input):
  19. message: AsyncGenerator[str, None]
  20. def __init__(
  21. self,
  22. llm_provider: CompletionProvider,
  23. database_provider: DatabaseProvider,
  24. config: QueryTransformConfig,
  25. *args,
  26. **kwargs,
  27. ):
  28. logger.info(f"Initalizing an `QueryTransformPipe` pipe.")
  29. super().__init__(
  30. llm_provider,
  31. database_provider,
  32. config,
  33. *args,
  34. **kwargs,
  35. )
  36. self._config: QueryTransformPipe.QueryTransformConfig = config
  37. @property
  38. def config(self) -> QueryTransformConfig: # type: ignore
  39. return self._config
  40. async def _run_logic( # type: ignore
  41. self,
  42. input: AsyncPipe.Input,
  43. state: AsyncState,
  44. run_id: UUID,
  45. query_transform_generation_config: GenerationConfig,
  46. num_query_xf_outputs: int = 3,
  47. *args: Any,
  48. **kwargs: Any,
  49. ) -> AsyncGenerator[str, None]:
  50. async for query in input.message:
  51. logger.info(
  52. f"Transforming query: {query} into {num_query_xf_outputs} outputs with {self.config.task_prompt}."
  53. )
  54. query_transform_request = await self.database_provider.prompts_handler.get_message_payload(
  55. system_prompt_name=self.config.system_prompt,
  56. task_prompt_name=self.config.task_prompt,
  57. task_inputs={
  58. "message": query,
  59. "num_outputs": num_query_xf_outputs,
  60. },
  61. )
  62. response = await self.llm_provider.aget_completion(
  63. messages=query_transform_request,
  64. generation_config=query_transform_generation_config,
  65. )
  66. content = response.choices[0].message.content
  67. if not content:
  68. logger.error(f"Failed to transform query: {query}. Skipping.")
  69. raise ValueError(f"Failed to transform query: {query}.")
  70. outputs = content.split("\n")
  71. outputs = [
  72. output.strip() for output in outputs if output.strip() != ""
  73. ]
  74. await state.update(
  75. self.config.name, {"output": {"outputs": outputs}}
  76. )
  77. for output in outputs:
  78. logger.info(f"Yielding transformed output: {output}")
  79. yield output