chunk_search_pipe.py 4.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143
import json
import logging
from typing import Any, AsyncGenerator, Optional
from uuid import UUID

from core.base import (
    AsyncPipe,
    AsyncState,
    ChunkSearchResult,
    DatabaseProvider,
    EmbeddingProvider,
    EmbeddingPurpose,
    SearchSettings,
)

from ..abstractions.search_pipe import SearchPipe
  15. logger = logging.getLogger()
  16. class VectorSearchPipe(SearchPipe):
  17. def __init__(
  18. self,
  19. database_provider: DatabaseProvider,
  20. embedding_provider: EmbeddingProvider,
  21. config: SearchPipe.SearchConfig,
  22. *args,
  23. **kwargs,
  24. ):
  25. super().__init__(
  26. config,
  27. *args,
  28. **kwargs,
  29. )
  30. self.embedding_provider = embedding_provider
  31. self.database_provider = database_provider
  32. self._config: SearchPipe.SearchConfig = config
  33. @property
  34. def config(self) -> SearchPipe.SearchConfig:
  35. return self._config
  36. async def search( # type: ignore
  37. self,
  38. message: str,
  39. search_settings: SearchSettings,
  40. *args: Any,
  41. **kwargs: Any,
  42. ) -> AsyncGenerator[ChunkSearchResult, None]:
  43. if search_settings.chunk_settings.enabled == False:
  44. return
  45. search_settings.filters = (
  46. search_settings.filters or self.config.filters
  47. )
  48. search_settings.limit = search_settings.limit or self.config.limit
  49. results = []
  50. query_vector = await self.embedding_provider.async_get_embedding(
  51. message,
  52. purpose=EmbeddingPurpose.QUERY,
  53. )
  54. if (
  55. search_settings.use_fulltext_search
  56. and search_settings.use_semantic_search
  57. ) or search_settings.use_hybrid_search:
  58. search_results = (
  59. await self.database_provider.chunks_handler.hybrid_search(
  60. query_vector=query_vector,
  61. query_text=message,
  62. search_settings=search_settings,
  63. )
  64. )
  65. elif search_settings.use_fulltext_search:
  66. search_results = (
  67. await self.database_provider.chunks_handler.full_text_search(
  68. query_text=message,
  69. search_settings=search_settings,
  70. )
  71. )
  72. elif search_settings.use_semantic_search:
  73. search_results = (
  74. await self.database_provider.chunks_handler.semantic_search(
  75. query_vector=query_vector,
  76. search_settings=search_settings,
  77. )
  78. )
  79. else:
  80. raise ValueError(
  81. "At least one of use_fulltext_search or use_semantic_search must be True"
  82. )
  83. reranked_results = await self.embedding_provider.arerank(
  84. query=message,
  85. results=search_results,
  86. limit=search_settings.limit,
  87. )
  88. if kwargs.get("include_title_if_available", False):
  89. for result in reranked_results:
  90. if title := result.metadata.get("title", None):
  91. text = result.text
  92. result.text = f"Document Title:{title}\n\nText:{text}"
  93. for result in reranked_results:
  94. result.metadata["associated_query"] = message
  95. results.append(result)
  96. yield result
  97. async def _run_logic( # type: ignore
  98. self,
  99. input: AsyncPipe.Input,
  100. state: AsyncState,
  101. run_id: UUID,
  102. search_settings: SearchSettings = SearchSettings(),
  103. *args: Any,
  104. **kwargs: Any,
  105. ) -> AsyncGenerator[ChunkSearchResult, None]:
  106. async for search_request in input.message:
  107. search_results = []
  108. async for result in self.search(
  109. search_request,
  110. search_settings,
  111. *args,
  112. **kwargs,
  113. ):
  114. search_results.append(result)
  115. yield result
  116. await state.update(
  117. self.config.name,
  118. {"output": {"search_results": search_results}},
  119. )
  120. await state.update(
  121. self.config.name,
  122. {
  123. "output": {
  124. "search_query": search_request,
  125. "search_results": search_results,
  126. }
  127. },
  128. )