test_retrieval_cli.py 5.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213
  1. """
  2. Tests for the retrieval commands in the CLI.
  3. - search
  4. - rag
  5. """
  6. import json
  7. import tempfile
  8. import pytest
  9. from click.testing import CliRunner
  10. from cli.commands.documents import create as create_document
  11. from cli.commands.retrieval import rag, search
  12. from r2r import R2RAsyncClient
  13. from tests.cli.async_invoke import async_invoke
  14. def extract_json_block(output: str) -> dict:
  15. """Extract and parse the first valid JSON object found in the output."""
  16. start = output.find("{")
  17. if start == -1:
  18. raise ValueError("No JSON object start found in output")
  19. brace_count = 0
  20. for i, char in enumerate(output[start:], start=start):
  21. if char == "{":
  22. brace_count += 1
  23. elif char == "}":
  24. brace_count -= 1
  25. if brace_count == 0:
  26. json_str = output[start : i + 1].strip()
  27. return json.loads(json_str)
  28. raise ValueError("No complete JSON object found in output")
  29. async def create_test_document(
  30. runner: CliRunner, client: R2RAsyncClient
  31. ) -> str:
  32. """Helper function to create a test document and return its ID."""
  33. with tempfile.NamedTemporaryFile(
  34. mode="w", suffix=".txt", delete=False
  35. ) as f:
  36. f.write(
  37. "This is a test document about artificial intelligence and machine learning. "
  38. "AI systems can be trained on large datasets to perform various tasks."
  39. )
  40. temp_path = f.name
  41. create_result = await async_invoke(
  42. runner, create_document, temp_path, obj=client
  43. )
  44. response = extract_json_block(create_result.stdout_bytes.decode())
  45. return response["results"]["document_id"]
  46. @pytest.mark.asyncio
  47. async def test_basic_search():
  48. """Test basic search functionality."""
  49. client = R2RAsyncClient(base_url="http://localhost:7272")
  50. runner = CliRunner(mix_stderr=False)
  51. # Create test document first
  52. document_id = await create_test_document(runner, client)
  53. try:
  54. # Test basic search
  55. search_result = await async_invoke(
  56. runner,
  57. search,
  58. "--query",
  59. "artificial intelligence",
  60. "--limit",
  61. "5",
  62. obj=client,
  63. )
  64. assert search_result.exit_code == 0
  65. assert "Vector search results:" in search_result.stdout_bytes.decode()
  66. finally:
  67. # Cleanup will be handled by document deletion in a real implementation
  68. pass
  69. @pytest.mark.asyncio
  70. async def test_search_with_filters():
  71. """Test search with filters."""
  72. client = R2RAsyncClient(base_url="http://localhost:7272")
  73. runner = CliRunner(mix_stderr=False)
  74. document_id = await create_test_document(runner, client)
  75. try:
  76. filters = json.dumps({"document_id": {"$in": [document_id]}})
  77. search_result = await async_invoke(
  78. runner,
  79. search,
  80. "--query",
  81. "machine learning",
  82. "--filters",
  83. filters,
  84. "--limit",
  85. "5",
  86. obj=client,
  87. )
  88. assert search_result.exit_code == 0
  89. output = search_result.stdout_bytes.decode()
  90. assert "Vector search results:" in output
  91. assert document_id in output
  92. finally:
  93. pass
  94. @pytest.mark.asyncio
  95. async def test_search_with_advanced_options():
  96. """Test search with advanced options."""
  97. client = R2RAsyncClient(base_url="http://localhost:7272")
  98. runner = CliRunner(mix_stderr=False)
  99. document_id = await create_test_document(runner, client)
  100. try:
  101. search_result = await async_invoke(
  102. runner,
  103. search,
  104. "--query",
  105. "AI systems",
  106. "--use-hybrid-search",
  107. "true",
  108. "--search-strategy",
  109. "vanilla",
  110. "--graph-search-enabled",
  111. "true",
  112. "--chunk-search-enabled",
  113. "true",
  114. obj=client,
  115. )
  116. assert search_result.exit_code == 0
  117. output = search_result.stdout_bytes.decode()
  118. assert "Vector search results:" in output
  119. finally:
  120. pass
  121. @pytest.mark.asyncio
  122. async def test_basic_rag():
  123. """Test basic RAG functionality."""
  124. client = R2RAsyncClient(base_url="http://localhost:7272")
  125. runner = CliRunner(mix_stderr=False)
  126. document_id = await create_test_document(runner, client)
  127. try:
  128. rag_result = await async_invoke(
  129. runner,
  130. rag,
  131. "--query",
  132. "What is this document about?",
  133. obj=client,
  134. )
  135. assert rag_result.exit_code == 0
  136. finally:
  137. pass
  138. @pytest.mark.asyncio
  139. async def test_rag_with_streaming():
  140. """Test RAG with streaming enabled."""
  141. client = R2RAsyncClient(base_url="http://localhost:7272")
  142. runner = CliRunner(mix_stderr=False)
  143. document_id = await create_test_document(runner, client)
  144. try:
  145. rag_result = await async_invoke(
  146. runner,
  147. rag,
  148. "--query",
  149. "What is this document about?",
  150. "--stream",
  151. obj=client,
  152. )
  153. assert rag_result.exit_code == 0
  154. finally:
  155. pass
  156. @pytest.mark.asyncio
  157. async def test_rag_with_model_specification():
  158. """Test RAG with specific model."""
  159. client = R2RAsyncClient(base_url="http://localhost:7272")
  160. runner = CliRunner(mix_stderr=False)
  161. document_id = await create_test_document(runner, client)
  162. try:
  163. rag_result = await async_invoke(
  164. runner,
  165. rag,
  166. "--query",
  167. "What is this document about?",
  168. "--rag-model",
  169. "azure/gpt-4o-mini",
  170. obj=client,
  171. )
  172. assert rag_result.exit_code == 0
  173. finally:
  174. pass