# test_retrieval_advanced.py

import re

import pytest
from r2r import Message, R2RException


# Semantic Search Tests
def test_semantic_search_with_near_duplicates(client):
    """Test semantic search can handle and differentiate near-duplicate content"""
    # Create two similar but distinct documents
    doc1 = client.documents.create(
        raw_text="Aristotle was a Greek philosopher who studied logic."
    )["results"]["document_id"]
    doc2 = client.documents.create(
        raw_text="Aristotle, the Greek philosopher, studied formal logic."
    )["results"]["document_id"]

    resp = client.retrieval.search(
        query="Tell me about Aristotle's work in logic",
        search_mode="custom",
        search_settings={"use_semantic_search": True, "limit": 5},
    )
    results = resp["results"]["chunk_search_results"]

    # Both documents should be returned, but with different scores
    scores = [r["score"] for r in results if r["document_id"] in [doc1, doc2]]
    assert len(scores) == 2, "Expected both similar documents"
    assert (
        len(set(scores)) == 2
    ), "Expected different scores for similar documents"

def test_semantic_search_multilingual(client):
    """Test semantic search handles multilingual content"""
    # Create documents in different languages
    docs = [
        ("Aristotle was a philosopher", "English"),
        ("Aristóteles fue un filósofo", "Spanish"),
        ("アリストテレスは哲学者でした", "Japanese"),
    ]
    doc_ids = []
    for text, lang in docs:
        doc_id = client.documents.create(
            raw_text=text, metadata={"language": lang}
        )["results"]["document_id"]
        doc_ids.append(doc_id)

    # Query in different languages
    queries = [
        "Who was Aristotle?",
        "¿Quién fue Aristóteles?",
        "アリストテレスとは誰でしたか?",
    ]
    for query in queries:
        resp = client.retrieval.search(
            query=query,
            search_mode="custom",
            search_settings={
                "use_semantic_search": True,
                "limit": len(doc_ids),
            },
        )
        results = resp["results"]["chunk_search_results"]
        assert len(results) > 0, f"No results found for query: {query}"

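# Hedged follow-up sketch: the documents above carry a "language" metadata
# field, so a per-language search can reuse the filter syntax used elsewhere
# in this file. Whether metadata keys are addressed as "metadata.language" or
# plain "language" depends on your R2R version; treat the key path below as
# an assumption to verify.
def _search_in_language(client, query: str, language: str):
    return client.retrieval.search(
        query=query,
        search_mode="custom",
        search_settings={
            "use_semantic_search": True,
            "filters": {"metadata.language": {"$eq": language}},  # assumed key path
            "limit": 5,
        },
    )
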
# Hybrid Search Tests
@pytest.mark.skip(reason="TODO: re-enable later")
def test_hybrid_search_weight_balance(client):
    """Test hybrid search balances semantic and full-text scores appropriately"""
    # Create a document with high semantic relevance but low keyword match
    semantic_doc = client.documents.create(
        raw_text="The ancient Greek thinker who studied under Plato made significant contributions to logic."
    )["results"]["document_id"]
    # Create a document with high keyword match but low semantic relevance
    keyword_doc = client.documents.create(
        raw_text="Aristotle is a common name in certain regions. This text mentions Aristotle but is not about philosophy."
    )["results"]["document_id"]

    resp = client.retrieval.search(
        query="What were Aristotle's philosophical contributions?",
        search_mode="custom",
        search_settings={
            "use_hybrid_search": True,
            "hybrid_settings": {
                "semantic_weight": 0.7,
                "full_text_weight": 0.3,
            },
        },
    )
    results = resp["results"]["chunk_search_results"]

    # The semantic document should rank higher
    semantic_rank = next(
        i for i, r in enumerate(results) if r["document_id"] == semantic_doc
    )
    keyword_rank = next(
        i for i, r in enumerate(results) if r["document_id"] == keyword_doc
    )
    assert (
        semantic_rank < keyword_rank
    ), "Semantic relevance should outweigh keyword matches"

# RAG Tests
def test_rag_context_window_limits(client):
    """Test RAG handles documents at or near context window limits"""
    # Create a document that approaches the context window limit
    large_text = (
        "Aristotle " * 1000
    )  # Adjust multiplier based on your context window
    doc_id = client.documents.create(raw_text=large_text)["results"][
        "document_id"
    ]
    resp = client.retrieval.rag(
        query="Summarize this text about Aristotle",
        search_settings={"filters": {"document_id": {"$eq": str(doc_id)}}},
        rag_generation_config={"max_tokens": 100},
    )
    assert "results" in resp, "RAG should handle large context gracefully"

@pytest.mark.skip(reason="TODO: re-enable later")
def test_rag_empty_chunk_handling(client):
    """Test RAG properly handles empty or whitespace-only chunks"""
    doc_id = client.documents.create(chunks=["", " ", "\n", "Valid content"])[
        "results"
    ]["document_id"]
    resp = client.retrieval.rag(
        query="What is the content?",
        search_settings={"filters": {"document_id": {"$eq": str(doc_id)}}},
    )
    assert "results" in resp, "RAG should handle empty chunks gracefully"

# Agent Tests
@pytest.mark.skip(reason="TODO: re-enable later")
def test_agent_clarification_requests(client):
    """Test agent's ability to request clarification for ambiguous queries"""
    msg = Message(role="user", content="Compare them")
    resp = client.retrieval.agent(
        message=msg,
        search_settings={"use_semantic_search": True},
    )
    content = resp["results"]["messages"][-1]["content"]
    assert any(
        phrase in content.lower()
        for phrase in [
            "could you clarify",
            "who do you",
            "what would you",
            "please specify",
        ]
    ), "Agent should request clarification for ambiguous queries"

@pytest.mark.skip(reason="TODO: re-enable later")
def test_agent_source_citation_consistency(client):
    """Test agent consistently cites sources across conversation turns"""
    conversation_id = client.conversations.create()["results"]["id"]

    # First turn - asking about a specific topic
    msg1 = Message(role="user", content="What did Aristotle say about ethics?")
    resp1 = client.retrieval.agent(
        message=msg1,
        conversation_id=conversation_id,
        include_title_if_available=True,
    )

    # Second turn - asking for more details
    msg2 = Message(role="user", content="Can you elaborate on that point?")
    resp2 = client.retrieval.agent(
        message=msg2,
        conversation_id=conversation_id,
        include_title_if_available=True,
    )

    # Check that sources are consistently cited across turns
    sources1 = _extract_sources(resp1["results"]["messages"][-1]["content"])
    sources2 = _extract_sources(resp2["results"]["messages"][-1]["content"])
    assert (
        len(sources1) > 0 and len(sources2) > 0
    ), "Both responses should cite sources"
    assert any(
        s in sources2 for s in sources1
    ), "Follow-up should reference some original sources"

# Error Handling Tests
@pytest.mark.skip(reason="TODO: re-enable later")
def test_malformed_filter_handling(client):
    """Test system properly handles malformed filter conditions"""
    invalid_filters = [
        {"$invalid": {"$eq": "value"}},
        {"field": {"$unsupported": "value"}},
        {"$and": [{"field": "incomplete_operator"}]},
        {"$or": []},  # Empty OR condition
        {"$and": [{}]},  # Empty filter in AND
    ]
    for invalid_filter in invalid_filters:
        with pytest.raises(R2RException) as exc_info:
            client.retrieval.search(
                query="test", search_settings={"filters": invalid_filter}
            )
        assert exc_info.value.status_code in [
            400,
            422,
        ], f"Expected validation error for filter: {invalid_filter}"

@pytest.mark.skip(reason="TODO: re-enable later")
def test_concurrent_search_stability(client):
    """Test system handles concurrent search requests properly"""
    import asyncio

    async def concurrent_searches():
        tasks = []
        for i in range(10):  # Adjust number based on system capabilities
            task = asyncio.create_task(
                client.retrieval.search_async(
                    query=f"Concurrent test query {i}", search_mode="basic"
                )
            )
            tasks.append(task)
        results = await asyncio.gather(*tasks, return_exceptions=True)
        return results

    results = asyncio.run(concurrent_searches())
    assert all(
        not isinstance(r, Exception) for r in results
    ), "Concurrent searches should complete without errors"

# Helper function for source extraction
def _extract_sources(content: str) -> list[str]:
    """Extract source citations from response content"""
    # This is a simplified version - implement based on your citation format
    return re.findall(r'"([^"]*)"', content)
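# Example of what the pattern above matches: every double-quoted span is
# treated as a citation, so
#     _extract_sources('Aristotle wrote "Nicomachean Ethics" and "Poetics".')
# returns ["Nicomachean Ethics", "Poetics"]. Tighten the regex if responses
# quote text other than source titles.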