search_rag_pipe.py 4.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124
  1. from typing import Any, AsyncGenerator, Optional, Tuple
  2. from uuid import UUID
  3. from core.base import (
  4. AggregateSearchResult,
  5. AsyncPipe,
  6. AsyncState,
  7. CompletionProvider,
  8. DatabaseProvider,
  9. KGSearchResultType,
  10. )
  11. from core.base.abstractions import GenerationConfig, RAGCompletion
  12. from ..abstractions.generator_pipe import GeneratorPipe
  13. 3.0
  14. class RAGPipe(GeneratorPipe):
  15. class Input(AsyncPipe.Input):
  16. message: AsyncGenerator[Tuple[str, AggregateSearchResult], None]
  17. def __init__(
  18. self,
  19. llm_provider: CompletionProvider,
  20. database_provider: DatabaseProvider,
  21. config: GeneratorPipe.PipeConfig,
  22. *args,
  23. **kwargs,
  24. ):
  25. super().__init__(
  26. llm_provider,
  27. database_provider,
  28. config,
  29. *args,
  30. **kwargs,
  31. )
  32. self._config: GeneratorPipe.PipeConfig = config
  33. @property
  34. def config(self) -> GeneratorPipe.PipeConfig: # for type hiting
  35. return self._config
  36. async def _run_logic( # type: ignore
  37. self,
  38. input: Input,
  39. state: AsyncState,
  40. run_id: UUID,
  41. rag_generation_config: GenerationConfig,
  42. *args: Any,
  43. **kwargs: Any,
  44. ) -> AsyncGenerator[RAGCompletion, None]:
  45. context = ""
  46. search_iteration = 1
  47. total_results = 0
  48. sel_query = None
  49. async for query, search_results in input.message:
  50. if search_iteration == 1:
  51. sel_query = query
  52. context_piece, total_results = await self._collect_context(
  53. query, search_results, search_iteration, total_results
  54. )
  55. context += context_piece
  56. search_iteration += 1
  57. messages = (
  58. await self.database_provider.prompts_handler.get_message_payload(
  59. system_prompt_name=self.config.system_prompt,
  60. task_prompt_name=self.config.task_prompt,
  61. task_inputs={"query": sel_query, "context": context},
  62. task_prompt_override=kwargs.get("task_prompt_override", None),
  63. )
  64. )
  65. response = await self.llm_provider.aget_completion(
  66. messages=messages, generation_config=rag_generation_config
  67. )
  68. yield RAGCompletion(completion=response, search_results=search_results)
  69. if run_id:
  70. content = response.choices[0].message.content
  71. if not content:
  72. raise ValueError("Response content is empty")
  73. async def _collect_context(
  74. self,
  75. query: str,
  76. results: AggregateSearchResult,
  77. iteration: int,
  78. total_results: int,
  79. ) -> Tuple[str, int]:
  80. context = f"Query:\n{query}\n\n"
  81. if results.chunk_search_results:
  82. context += f"Vector Search Results({iteration}):\n"
  83. it = total_results + 1
  84. for result in results.chunk_search_results:
  85. context += f"[{it}]: {result.text}\n\n"
  86. it += 1
  87. total_results = (
  88. it - 1
  89. ) # Update total_results based on the last index used
  90. if results.graph_search_results:
  91. context += f"Knowledge Graph ({iteration}):\n"
  92. it = total_results + 1
  93. for search_result in results.graph_search_results: # [1]:
  94. # if associated_query := search_results.metadata.get(
  95. # "associated_query"
  96. # ):
  97. # context += f"Query: {associated_query}\n\n"
  98. # context += f"Results:\n"
  99. if search_result.result_type == KGSearchResultType.ENTITY:
  100. context += f"[{it}]: Entity Name - {search_result.content.name}\n\nDescription - {search_result.content.description}\n\n"
  101. elif (
  102. search_result.result_type
  103. == KGSearchResultType.RELATIONSHIP
  104. ):
  105. context += f"[{it}]: Relationship - {search_result.content.subject} - {search_result.content.predicate} - {search_result.content.object}\n\n"
  106. else:
  107. context += f"[{it}]: Community Name - {search_result.content.name}\n\nDescription - {search_result.content.summary}\n\n"
  108. it += 1
  109. total_results = (
  110. it - 1
  111. ) # Update total_results based on the last index used
  112. return context, total_results