multi_search.py 5.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149
  1. from copy import copy, deepcopy
  2. from typing import Any, AsyncGenerator, Optional
  3. from uuid import UUID
  4. from core.base.abstractions import (
  5. ChunkSearchResult,
  6. GenerationConfig,
  7. SearchSettings,
  8. )
  9. from core.base.pipes.base_pipe import AsyncPipe
  10. from ..abstractions.search_pipe import SearchPipe
  11. from .query_transform_pipe import QueryTransformPipe
  12. class MultiSearchPipe(AsyncPipe):
  13. class PipeConfig(AsyncPipe.PipeConfig):
  14. name: str = "multi_search_pipe"
  15. use_rrf: bool = False
  16. rrf_k: int = 60 # RRF constant
  17. num_queries: int = 3
  18. expansion_factor: int = 3 # Factor to expand results before RRF
  19. def __init__(
  20. self,
  21. query_transform_pipe: QueryTransformPipe,
  22. inner_search_pipe: SearchPipe,
  23. config: PipeConfig,
  24. *args,
  25. **kwargs,
  26. ):
  27. self.query_transform_pipe = query_transform_pipe
  28. self.vector_search_pipe = inner_search_pipe
  29. config = config or MultiSearchPipe.PipeConfig(
  30. name=query_transform_pipe.config.name
  31. )
  32. super().__init__(
  33. config,
  34. *args,
  35. **kwargs,
  36. )
  37. self._config: MultiSearchPipe.PipeConfig = config # for type hinting
  38. @property
  39. def config(self) -> PipeConfig:
  40. return self._config
  41. async def _run_logic( # type: ignore
  42. self,
  43. input: Any,
  44. state: Any,
  45. run_id: UUID,
  46. search_settings: SearchSettings,
  47. query_transform_generation_config: Optional[GenerationConfig] = None,
  48. *args: Any,
  49. **kwargs: Any,
  50. ) -> AsyncGenerator[ChunkSearchResult, None]:
  51. query_transform_generation_config = (
  52. query_transform_generation_config
  53. or copy(kwargs.get("rag_generation_config", None))
  54. or GenerationConfig(model="azure/gpt-4o")
  55. )
  56. query_transform_generation_config.stream = False
  57. query_generator = await self.query_transform_pipe.run(
  58. input,
  59. state,
  60. query_transform_generation_config=query_transform_generation_config,
  61. num_query_xf_outputs=self.config.num_queries,
  62. *args,
  63. **kwargs,
  64. )
  65. if self.config.use_rrf:
  66. search_settings.limit = (
  67. self.config.expansion_factor * search_settings.limit
  68. )
  69. results = []
  70. async for search_result in await self.vector_search_pipe.run(
  71. self.vector_search_pipe.Input(message=query_generator),
  72. state,
  73. search_settings=search_settings,
  74. *args,
  75. **kwargs,
  76. ):
  77. results.append(search_result)
  78. # Collection results by their associated queries
  79. grouped_results: dict[str, list[ChunkSearchResult]] = {}
  80. for result in results:
  81. query = result.metadata["associated_query"]
  82. if query not in grouped_results:
  83. grouped_results[query] = []
  84. grouped_results[query].append(result)
  85. fused_results = self.reciprocal_rank_fusion(grouped_results)
  86. for result in fused_results[: search_settings.limit]:
  87. yield result
  88. else:
  89. async for search_result in await self.vector_search_pipe.run(
  90. self.vector_search_pipe.Input(message=query_generator),
  91. state,
  92. search_settings=search_settings,
  93. *args,
  94. **kwargs,
  95. ):
  96. yield search_result
  97. def reciprocal_rank_fusion(
  98. self, all_results: dict[str, list[ChunkSearchResult]]
  99. ) -> list[ChunkSearchResult]:
  100. document_scores: dict[UUID, float] = {}
  101. document_results: dict[UUID, ChunkSearchResult] = {}
  102. document_queries: dict[UUID, set[str]] = {}
  103. for query, results in all_results.items():
  104. for rank, result in enumerate(results, 1):
  105. doc_id = result.chunk_id
  106. if doc_id not in document_scores:
  107. document_scores[doc_id] = 0
  108. document_results[doc_id] = result
  109. set_: set[str] = set()
  110. document_queries[doc_id] = set_
  111. document_scores[doc_id] += 1 / (rank + self.config.rrf_k)
  112. document_queries[doc_id].add(query) # type: ignore
  113. # Sort documents by their RRF score
  114. sorted_docs = sorted(
  115. document_scores.items(), key=lambda x: x[1], reverse=True
  116. )
  117. # Reconstruct ChunkSearchResults with new ranking, RRF score, and associated queries
  118. fused_results = []
  119. for doc_id, rrf_score in sorted_docs:
  120. result = deepcopy(document_results[doc_id])
  121. result.score = (
  122. rrf_score # Replace the original score with the RRF score
  123. )
  124. result.metadata["associated_queries"] = list(
  125. document_queries[doc_id] # type: ignore
  126. ) # Add list of associated queries
  127. result.metadata["is_rrf_score"] = True
  128. if "associated_query" in result.metadata:
  129. del result.metadata[
  130. "associated_query"
  131. ] # Remove the old single associated_query
  132. fused_results.append(result)
  133. return fused_results