# test_retrieval_old.py (6.4 KB)

  1. from unittest.mock import AsyncMock
  2. import pytest
  3. @pytest.fixture
  4. def mock_providers():
  5. """
  6. Return a fake R2RProviders object with all relevant sub-providers mocked.
  7. """
  8. class MockProviders:
  9. def __init__(self):
  10. # Mock the embedding provider
  11. self.completion_embedding = AsyncMock()
  12. self.completion_embedding.async_get_embedding = AsyncMock(
  13. return_value=[0.123] * 768 # pretend vector
  14. )
  15. self.completion_embedding.arerank = AsyncMock(return_value=[])
  16. # Mock the chunk search provider
  17. self.database = AsyncMock()
  18. self.database.chunks_handler.hybrid_search = AsyncMock(
  19. return_value=[]
  20. )
  21. self.database.chunks_handler.semantic_search = AsyncMock(
  22. return_value=[]
  23. )
  24. self.database.chunks_handler.full_text_search = AsyncMock(
  25. return_value=[]
  26. )
  27. # Mock the graph search
  28. self.database.graphs_handler.graph_search = AsyncMock(
  29. return_value=iter([])
  30. )
  31. # Optional: If you want to test agent logic, mock those too
  32. self.llm = AsyncMock()
  33. self.llm.aget_completion = AsyncMock()
  34. self.llm.aget_completion_stream = AsyncMock()
  35. self.database.prompts_handler.get_cached_prompt = AsyncMock(
  36. return_value="(fake hyde template here)"
  37. )
  38. return MockProviders()
  39. @pytest.fixture
  40. def retrieval_service(mock_providers):
  41. """
  42. Construct your RetrievalService with the mocked providers.
  43. """
  44. from core import R2RConfig # adjust import as needed
  45. config = R2RConfig({}) # or however you normally build it
  46. providers = mock_providers
  47. # If your constructor is something like:
  48. from core.main.services import RetrievalService # example
  49. service = RetrievalService(config=config, providers=providers)
  50. return service
# @pytest.mark.asyncio
# async def test_basic_search_calls_once(retrieval_service):
#     """
#     Ensure that in 'basic' strategy, we only do 1 chunk search & 1 graph search
#     (assuming use_semantic_search=True and chunk_settings.enabled=True, etc.).
#     """
#     s = SearchSettings(
#         search_strategy="vanilla",  # or "basic"
#         use_semantic_search=True,
#         chunk_settings={"enabled": True},
#         graph_settings={"enabled": True},
#     )
#     await retrieval_service.search("Aristotle", s)
#     # we expect 1 call to chunk search, 1 call to graph search
#     chunk_handler = retrieval_service.providers.database.chunks_handler
#     graph_handler = retrieval_service.providers.database.graphs_handler
#     # Because we used semantic_search or hybrid, let's see which was called:
#     # If your code used hybrid by default, check `hybrid_search.call_count`
#     assert (
#         chunk_handler.hybrid_search.call_count
#         + chunk_handler.semantic_search.call_count
#         + chunk_handler.full_text_search.call_count
#         == 1
#     ), "Expected exactly 1 chunk search call in basic mode"
#     # NOTE(review): the check below expects 3 graph_search calls, while the
#     # message and the docstring say 1 -- confirm the intended count before
#     # re-enabling this test.
#     assert (
#         graph_handler.graph_search.call_count == 3
#     ), "Expected exactly 1 graph search call in basic mode"
# @pytest.mark.asyncio
# async def test_hyde_search_fans_out_correctly(retrieval_service):
#     """
#     In 'hyde' strategy with num_sub_queries=2, we should:
#     - generate 2 hypothetical docs
#     - for each doc => embed alt_text => run chunk+graph => total 2 chunk searches, 2 graph searches
#     """
#     s = SearchSettings(
#         search_strategy="hyde",
#         num_sub_queries=2,
#         use_semantic_search=True,
#         chunk_settings={"enabled": True},
#         graph_settings={"enabled": True},
#     )
#     await retrieval_service.search("Aristotle", s)
#     chunk_handler = retrieval_service.providers.database.chunks_handler
#     graph_handler = retrieval_service.providers.database.graphs_handler
#     embedding_mock = (
#         retrieval_service.providers.completion_embedding.async_get_embedding
#     )
#     # For chunk search, each sub-query => 1 chunk search => total 2 calls
#     # (If you see more, maybe your code does something else.)
#     total_chunk_calls = (
#         chunk_handler.hybrid_search.call_count
#         + chunk_handler.semantic_search.call_count
#         + chunk_handler.full_text_search.call_count
#     )
#     print('total_chunk_calls = ', total_chunk_calls)
#     # Check how many times we called embedding
#     # 1) Possibly the code might embed "Aristotle" once if it re-ranks with user_text (though you might not do that).
#     # 2) The code definitely calls embed for each "hyde doc" -> 2 sub queries => 2 calls
#     # So you might see 2 or 3 total calls
#     assert (
#         embedding_mock.call_count >= 2
#     ), "We expected at least 2 embeddings for the hyde docs"
#     assert (
#         total_chunk_calls == 2
#     ), f"Expected exactly 2 chunk search calls (got {total_chunk_calls})"
#     # For graph search => also 2 calls
#     assert (
#         graph_handler.graph_search.call_count == 2
#     ), f"Expected exactly 2 graph search calls, got {graph_handler.graph_search.call_count}"
# @pytest.mark.asyncio
# async def test_rag_fusion_placeholder(retrieval_service):
#     """
#     We have a placeholder `_rag_fusion_search`, but it just calls `_basic_search`.
#     So let's verify it just triggers 1 chunk search / 1 graph search by default.
#     """
#     s = SearchSettings(
#         search_strategy="rag_fusion",
#         # if you haven't actually implemented multi-subqueries, it should
#         # just do the same as basic (1 chunk search, 1 graph search).
#         use_semantic_search=True,
#         chunk_settings={"enabled": True},
#         graph_settings={"enabled": True},
#     )
#     await retrieval_service.search("Aristotle", s)
#     chunk_handler = retrieval_service.providers.database.chunks_handler
#     graph_handler = retrieval_service.providers.database.graphs_handler
#     total_chunk_calls = (
#         chunk_handler.hybrid_search.call_count
#         + chunk_handler.semantic_search.call_count
#         + chunk_handler.full_text_search.call_count
#     )
#     assert (
#         total_chunk_calls == 1
#     ), "Placeholder RAG-Fusion should call 1 chunk search"
#     # NOTE(review): the check below expects 3 graph_search calls, while the
#     # message and the docstring say 1 -- confirm the intended count before
#     # re-enabling this test.
#     assert (
#         graph_handler.graph_search.call_count == 3
#     ), "Placeholder RAG-Fusion => 1 graph search"