test_retrieval_advanced.py 9.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259
import re
import uuid

from r2r import R2RClient
  3. # Semantic Search Tests
  4. def test_semantic_search_with_near_duplicates(client: R2RClient):
  5. """Test semantic search can handle and differentiate near-duplicate
  6. content."""
  7. random_1 = str(uuid.uuid4())
  8. random_2 = str(uuid.uuid4())
  9. # Create two similar but distinct documents
  10. doc1 = client.documents.create(
  11. raw_text=
  12. f"Aristotle was a Greek philosopher who studied logic {random_1}."
  13. ).results.document_id
  14. doc2 = client.documents.create(
  15. raw_text=
  16. f"Aristotle, the Greek philosopher, studied formal logic {random_2}."
  17. ).results.document_id
  18. resp = client.retrieval.search(
  19. query="Tell me about Aristotle's work in logic",
  20. search_mode="custom",
  21. search_settings={
  22. "use_semantic_search": True,
  23. "limit": 25
  24. },
  25. )
  26. results = resp.results.chunk_search_results
  27. # Both documents should be returned but with different scores
  28. scores = [
  29. r.score for r in results
  30. if str(r.document_id) in [str(doc1), str(doc2)]
  31. ]
  32. assert len(scores) == 2, "Expected both similar documents"
  33. assert len(
  34. set(scores)) == 2, ("Expected different scores for similar documents")
  35. def test_semantic_search_multilingual(client: R2RClient):
  36. """Test semantic search handles multilingual content."""
  37. # Create documents in different languages
  38. random_1 = str(uuid.uuid4())
  39. random_2 = str(uuid.uuid4())
  40. random_3 = str(uuid.uuid4())
  41. docs = [
  42. (f"Aristotle was a philosopher {random_1}", "English"),
  43. (f"Aristóteles fue un filósofo {random_2}", "Spanish"),
  44. (f"アリストテレスは哲学者でした {random_3}", "Japanese"),
  45. ]
  46. doc_ids = []
  47. for text, lang in docs:
  48. doc_id = client.documents.create(raw_text=text,
  49. metadata={
  50. "language": lang
  51. }).results.document_id
  52. doc_ids.append(doc_id)
  53. # Query in different languages
  54. queries = [
  55. "Who was Aristotle?",
  56. "¿Quién fue Aristóteles?",
  57. "アリストテレスとは誰でしたか?",
  58. ]
  59. for query in queries:
  60. resp = client.retrieval.search(
  61. query=query,
  62. search_mode="custom",
  63. search_settings={
  64. "use_semantic_search": True,
  65. "limit": len(doc_ids),
  66. },
  67. )
  68. results = resp.results.chunk_search_results
  69. assert len(results) > 0, f"No results found for query: {query}"
  70. # UNCOMMENT LATER
  71. # # Hybrid Search Tests
  72. # def test_hybrid_search_weight_balance(client: R2RClient):
  73. # """Test hybrid search balances semantic and full-text scores appropriately"""
  74. # # Create a document with high semantic relevance but low keyword match
  75. # semantic_doc = client.documents.create(
  76. # raw_text="The ancient Greek thinker who studied under Plato made significant contributions to logic."
  77. # ).results.document_id
  78. # # Create a document with high keyword match but low semantic relevance
  79. # keyword_doc = client.documents.create(
  80. # raw_text="Aristotle is a common name in certain regions. This text mentions Aristotle but is not about philosophy."
  81. # ).results.document_id
  82. # resp = client.retrieval.search(
  83. # query="What were Aristotle's philosophical contributions?",
  84. # search_mode="custom",
  85. # search_settings={
  86. # "use_hybrid_search": True,
  87. # "hybrid_settings": {
  88. # "semantic_weight": 0.7,
  89. # "full_text_weight": 0.3,
  90. # },
  91. # },
  92. # )
  93. # results = resp["results"]["chunk_search_results"]
  94. # # The semantic document should rank higher
  95. # semantic_rank = next(
  96. # i for i, r in enumerate(results) if r["document_id"] == semantic_doc
  97. # )
  98. # keyword_rank = next(
  99. # i for i, r in enumerate(results) if r["document_id"] == keyword_doc
  100. # )
  101. # assert (
  102. # semantic_rank < keyword_rank
  103. # ), "Semantic relevance should outweigh keyword matches"
  104. # RAG Tests
  105. def test_rag_context_window_limits(client: R2RClient):
  106. """Test RAG handles documents at or near context window limits."""
  107. # Create a document that approaches the context window limit
  108. random_1 = str(uuid.uuid4())
  109. large_text = ("Aristotle " * 1000
  110. ) # Adjust multiplier based on your context window
  111. doc_id = client.documents.create(
  112. raw_text=f"{large_text} {random_1}").results.document_id
  113. resp = client.retrieval.rag(
  114. query="Summarize this text about Aristotle",
  115. search_settings={"filters": {
  116. "document_id": {
  117. "$eq": str(doc_id)
  118. }
  119. }},
  120. rag_generation_config={"max_tokens": 100},
  121. )
  122. assert resp.results is not None, (
  123. "RAG should handle large context gracefully")
  124. # UNCOMMENT LATER
  125. # def test_rag_empty_chunk_handling(client: R2RClient):
  126. # """Test RAG properly handles empty or whitespace-only chunks"""
  127. # doc_id = client.documents.create(chunks=["", " ", "\n", "Valid content"])[
  128. # "results"
  129. # ]["document_id"]
  130. # resp = client.retrieval.rag(
  131. # query="What is the content?",
  132. # search_settings={"filters": {"document_id": {"$eq": str(doc_id)}}},
  133. # )
  134. # assert "results" in resp, "RAG should handle empty chunks gracefully"
  135. # # Agent Tests
  136. # def test_agent_clarification_requests(client: R2RClient):
  137. # """Test agent's ability to request clarification for ambiguous queries"""
  138. # msg = Message(role="user", content="Compare them")
  139. # resp = client.retrieval.agent(
  140. # message=msg,
  141. # search_settings={"use_semantic_search": True},
  142. # )
  143. # content = resp["results"]["messages"][-1]["content"]
  144. # assert any(
  145. # phrase in content.lower()
  146. # for phrase in [
  147. # "could you clarify",
  148. # "who do you",
  149. # "what would you",
  150. # "please specify",
  151. # ]
  152. # ), "Agent should request clarification for ambiguous queries"
  153. ## TODO - uncomment later
  154. # def test_agent_source_citation_consistency(client: R2RClient):
  155. # """Test agent consistently cites sources across conversation turns"""
  156. # conversation_id = client.conversations.create()["results"]["id"]
  157. # # First turn - asking about a specific topic
  158. # msg1 = Message(role="user", content="What did Aristotle say about ethics?")
  159. # resp1 = client.retrieval.agent(
  160. # message=msg1,
  161. # conversation_id=conversation_id,
  162. # include_title_if_available=True,
  163. # )
  164. # # Second turn - asking for more details
  165. # msg2 = Message(role="user", content="Can you elaborate on that point?")
  166. # resp2 = client.retrieval.agent(
  167. # message=msg2,
  168. # conversation_id=conversation_id,
  169. # include_title_if_available=True,
  170. # )
  171. # # Check that sources are consistently cited across turns
  172. # sources1 = _extract_sources(resp1["results"]["messages"][-1]["content"])
  173. # sources2 = _extract_sources(resp2["results"]["messages"][-1]["content"])
  174. # assert (
  175. # len(sources1) > 0 and len(sources2) > 0
  176. # ), "Both responses should cite sources"
  177. # assert any(
  178. # s in sources2 for s in sources1
  179. # ), "Follow-up should reference some original sources"
  180. ## TODO - uncomment later
  181. # # Error Handling Tests
  182. # def test_malformed_filter_handling(client: R2RClient):
  183. # """Test system properly handles malformed filter conditions"""
  184. # invalid_filters = [
  185. # {"$invalid": {"$eq": "value"}},
  186. # {"field": {"$unsupported": "value"}},
  187. # {"$and": [{"field": "incomplete_operator"}]},
  188. # {"$or": []}, # Empty OR condition
  189. # {"$and": [{}]}, # Empty filter in AND
  190. # ]
  191. # for invalid_filter in invalid_filters:
  192. # with pytest.raises(R2RException) as exc_info:
  193. # client.retrieval.search(
  194. # query="test", search_settings={"filters": invalid_filter}
  195. # )
  196. # assert exc_info.value.status_code in [
  197. # 400,
  198. # 422,
  199. # ], f"Expected validation error for filter: {invalid_filter}"
  200. ## TODO - Uncomment later
  201. # def test_concurrent_search_stability(client: R2RClient):
  202. # """Test system handles concurrent search requests properly"""
  203. # import asyncio
  204. # async def concurrent_searches():
  205. # tasks = []
  206. # for i in range(10): # Adjust number based on system capabilities
  207. # task = asyncio.create_task(
  208. # client.retrieval.search_async(
  209. # query=f"Concurrent test query {i}", search_mode="basic"
  210. # )
  211. # )
  212. # tasks.append(task)
  213. # results = await asyncio.gather(*tasks, return_exceptions=True)
  214. # return results
  215. # results = asyncio.run(concurrent_searches())
  216. # assert all(
  217. # not isinstance(r, Exception) for r in results
  218. # ), "Concurrent searches should complete without errors"
  219. # Helper function for source extraction
  220. def _extract_sources(content: str) -> list[str]:
  221. """Extract source citations from response content."""
  222. # This is a simplified version - implement based on your citation format
  223. import re
  224. return re.findall(r'"([^"]*)"', content)