# search_rag_pipe.py (4.4 KB)
  1. from typing import Any, AsyncGenerator, Optional, Tuple
  2. from uuid import UUID
  3. from core.base import (
  4. AggregateSearchResult,
  5. AsyncPipe,
  6. AsyncState,
  7. CompletionProvider,
  8. DatabaseProvider,
  9. KGSearchResultType,
  10. )
  11. from core.base.abstractions import GenerationConfig, RAGCompletion
  12. from ..abstractions.generator_pipe import GeneratorPipe
  13. class RAGPipe(GeneratorPipe):
  14. class Input(AsyncPipe.Input):
  15. message: AsyncGenerator[Tuple[str, AggregateSearchResult], None]
  16. def __init__(
  17. self,
  18. llm_provider: CompletionProvider,
  19. database_provider: DatabaseProvider,
  20. config: GeneratorPipe.PipeConfig,
  21. *args,
  22. **kwargs,
  23. ):
  24. super().__init__(
  25. llm_provider,
  26. database_provider,
  27. config,
  28. *args,
  29. **kwargs,
  30. )
  31. self._config: GeneratorPipe.PipeConfig = config
  32. @property
  33. def config(self) -> GeneratorPipe.PipeConfig: # for type hiting
  34. return self._config
  35. async def _run_logic( # type: ignore
  36. self,
  37. input: Input,
  38. state: AsyncState,
  39. run_id: UUID,
  40. rag_generation_config: GenerationConfig,
  41. *args: Any,
  42. **kwargs: Any,
  43. ) -> AsyncGenerator[RAGCompletion, None]:
  44. context = ""
  45. search_iteration = 1
  46. total_results = 0
  47. sel_query = None
  48. async for query, search_results in input.message:
  49. if search_iteration == 1:
  50. sel_query = query
  51. context_piece, total_results = await self._collect_context(
  52. query, search_results, search_iteration, total_results
  53. )
  54. context += context_piece
  55. search_iteration += 1
  56. messages = (
  57. await self.database_provider.prompts_handler.get_message_payload(
  58. system_prompt_name=self.config.system_prompt,
  59. task_prompt_name=self.config.task_prompt,
  60. task_inputs={"query": sel_query, "context": context},
  61. task_prompt_override=kwargs.get("task_prompt_override", None),
  62. )
  63. )
  64. response = await self.llm_provider.aget_completion(
  65. messages=messages, generation_config=rag_generation_config
  66. )
  67. yield RAGCompletion(completion=response, search_results=search_results)
  68. if run_id:
  69. content = response.choices[0].message.content
  70. if not content:
  71. raise ValueError("Response content is empty")
  72. async def _collect_context(
  73. self,
  74. query: str,
  75. results: AggregateSearchResult,
  76. iteration: int,
  77. total_results: int,
  78. ) -> Tuple[str, int]:
  79. context = f"Query:\n{query}\n\n"
  80. if results.chunk_search_results:
  81. context += f"Vector Search Results({iteration}):\n"
  82. it = total_results + 1
  83. for result in results.chunk_search_results:
  84. context += f"[{it}]: {result.text}\n\n"
  85. it += 1
  86. total_results = (
  87. it - 1
  88. ) # Update total_results based on the last index used
  89. if results.graph_search_results:
  90. context += f"Knowledge Graph ({iteration}):\n"
  91. it = total_results + 1
  92. for search_result in results.graph_search_results: # [1]:
  93. # if associated_query := search_results.metadata.get(
  94. # "associated_query"
  95. # ):
  96. # context += f"Query: {associated_query}\n\n"
  97. # context += f"Results:\n"
  98. if search_result.result_type == KGSearchResultType.ENTITY:
  99. context += f"[{it}]: Entity Name - {search_result.content.name}\n\nDescription - {search_result.content.description}\n\n"
  100. elif (
  101. search_result.result_type
  102. == KGSearchResultType.RELATIONSHIP
  103. ):
  104. context += f"[{it}]: Relationship - {search_result.content.subject} - {search_result.content.predicate} - {search_result.content.object}\n\n"
  105. else:
  106. context += f"[{it}]: Community Name - {search_result.content.name}\n\nDescription - {search_result.content.summary}\n\n"
  107. it += 1
  108. total_results = (
  109. it - 1
  110. ) # Update total_results based on the last index used
  111. return context, total_results