# test_agent.py — integration tests for the R2R retrieval agent (pytest).
  1. import time
  2. import uuid
  3. from r2r import R2RClient
  4. def test_agent_basic_response(client, test_collection):
  5. """Test basic agent response with minimal configuration."""
  6. response = client.retrieval.agent(
  7. message={"role": "user", "content": "Who was Aristotle?"},
  8. rag_generation_config={"stream": False, "max_tokens_to_sample": 100},
  9. )
  10. assert response.results.messages[-1].content, "Agent should provide a response"
  11. assert "Aristotle" in response.results.messages[-1].content, "Response should be relevant to query"
  12. def test_agent_conversation_memory(client, test_collection):
  13. """Test agent maintains conversation context across multiple turns."""
  14. conversation_id = client.conversations.create().results.id
  15. # First turn
  16. response1 = client.retrieval.agent(
  17. message={"role": "user", "content": "Who was Aristotle?"},
  18. conversation_id=str(conversation_id),
  19. rag_generation_config={"stream": False, "max_tokens_to_sample": 100},
  20. )
  21. # Second turn with follow-up that requires memory of first turn
  22. response2 = client.retrieval.agent(
  23. message={"role": "user", "content": "What were his main contributions?"},
  24. conversation_id=str(conversation_id),
  25. rag_generation_config={"stream": False, "max_tokens_to_sample": 100},
  26. )
  27. assert "contributions" in response2.results.messages[-1].content.lower(), "Agent should address follow-up question"
  28. assert not "who was aristotle" in response2.results.messages[-1].content.lower(), "Agent shouldn't repeat context explanation"
  29. def test_agent_rag_tool_usage(client, test_collection):
  30. """Test agent uses RAG tool for knowledge retrieval."""
  31. # Create unique document with specific content
  32. unique_id = str(uuid.uuid4())
  33. unique_content = f"Quantum entanglement is a physical phenomenon that occurs when pairs of particles interact. {unique_id}"
  34. doc_id = client.documents.create(raw_text=unique_content).results.document_id
  35. response = client.retrieval.agent(
  36. message={"role": "user", "content": f"According to the document, what is quantum entanglement? You must use the search_file_knowledge tool."},
  37. rag_tools=["search_file_knowledge"],
  38. rag_generation_config={"stream": False, "max_tokens_to_sample": 150},
  39. )
  40. assert "citations" in response.results.messages[-1].metadata, "Response should contain citations"
  41. assert len(response.results.messages[-1].metadata["citations"]) > 0, "Citations list should not be empty"
  42. assert str(doc_id) == response.results.messages[-1].metadata["citations"][0]["payload"]["document_id"], "Agent should use RAG tool to retrieve unique content"
  43. assert str("search_file_knowledge") == response.results.messages[-1].metadata["tool_calls"][-1]["name"], "Agent should use RAG tool to retrieve unique content"
  44. # Clean up
  45. client.documents.delete(id=doc_id)
  46. def test_agent_rag_tool_usage2(client, test_collection):
  47. """Test agent uses RAG tool for knowledge retrieval."""
  48. # Create unique document with specific content
  49. unique_id = str(uuid.uuid4())
  50. unique_content = f"Quantum entanglement is a physical phenomenon {unique_id} that occurs when pairs of particles interact."
  51. doc_id = client.documents.create(raw_text=unique_content).results.document_id
  52. response = client.retrieval.agent(
  53. message={"role": "user", "content": f"What is quantum entanglement? Mention {unique_id} in your response, be sure to both search your files and fetch the content."},
  54. rag_tools=["search_file_descriptions", "get_file_content"],
  55. rag_generation_config={"stream": False, "max_tokens_to_sample": 150},
  56. )
  57. # assert unique_id in response.results.messages[-1].content, "Agent should use RAG tool to retrieve unique content"
  58. # assert str(doc_id) == response.results.messages[-1].metadata["citations"][0]["payload"]["document_id"], "Agent should use RAG tool to retrieve unique content"
  59. assert str("search_file_descriptions") == response.results.messages[-1].metadata["tool_calls"][0]["name"], "Agent should use search_file_descriptions to retrieve unique content"
  60. assert str("get_file_content") == response.results.messages[-1].metadata["tool_calls"][1]["name"], "Agent should use get_file_content to retrieve unique content"
  61. # raise Exception("Test not implemented")
  62. # Clean up
  63. client.documents.delete(id=doc_id)
  64. # def test_agent_python_execution_tool(client, test_collection):
  65. # """Test agent uses Python execution tool for computation."""
  66. # response = client.retrieval.agent(
  67. # message={"role": "user", "content": "Calculate the factorial of 15! × 32 using Python. Return the result as a single string like 32812...."},
  68. # mode="research",
  69. # research_tools=["python_executor"],
  70. # research_generation_config={"stream": False, "max_tokens_to_sample": 200},
  71. # )
  72. # print(response)
  73. # assert "41845579776000" in response.results.messages[-1].content.replace(",",""), "Agent should execute Python code and return correct factorial result"
  74. # def test_agent_web_search_tool(client, monkeypatch):
  75. # """Test agent uses web search tool when appropriate."""
  76. # # Mock web search method to return predetermined results
  77. # def mock_web_search(*args, **kwargs):
  78. # return {"organic_results": [
  79. # {"title": "Recent COVID-19 Statistics", "link": "https://example.com/covid",
  80. # "snippet": "Latest COVID-19 statistics show declining cases worldwide."}
  81. # ]}
  82. # # Apply mock to appropriate method
  83. # monkeypatch.setattr("core.utils.serper.SerperClient.get_raw", mock_web_search)
  84. # response = client.retrieval.agent(
  85. # message={"role": "user", "content": "What are the latest COVID-19 statistics?"},
  86. # rag_tools=["web_search"],
  87. # rag_generation_config={"stream": False, "max_tokens_to_sample": 100},
  88. # )
  89. # print('response = ', response)
  90. # assert "declining cases" in response.results.messages[-1].content.lower(), "Agent should use web search tool for recent data"
  91. def test_research_agent_client(client):
  92. """Configure a client with research mode settings."""
  93. # This fixture helps avoid repetition in test setup
  94. return lambda message_content, tools=None: client.retrieval.agent(
  95. message={"role": "user", "content": message_content},
  96. mode="research",
  97. research_tools=tools or ["reasoning", "rag"],
  98. research_generation_config={"stream": False, "max_tokens_to_sample": 200},
  99. )
  100. def test_agent_respects_max_tokens(client, test_collection):
  101. """Test agent respects max_tokens configuration."""
  102. # Very small max_tokens
  103. short_response = client.retrieval.agent(
  104. message={"role": "user", "content": "Write a detailed essay about Aristotle's life and works."},
  105. rag_generation_config={"stream": False, "max_tokens_to_sample": 200},
  106. )
  107. # Larger max_tokens
  108. long_response = client.retrieval.agent(
  109. message={"role": "user", "content": "Write a detailed essay about Aristotle's life and works."},
  110. rag_generation_config={"stream": False, "max_tokens_to_sample": 500},
  111. )
  112. short_content = short_response.results.messages[-1].content
  113. long_content = long_response.results.messages[-1].content
  114. assert len(short_content) < len(long_content), "Short max_tokens should produce shorter response"
  115. assert len(short_content.split()) < 200, "Short response should be very brief"
  116. def test_agent_model_selection(client, test_collection):
  117. """Test agent works with different LLM models."""
  118. # Test with default model
  119. default_response = client.retrieval.agent(
  120. message={"role": "user", "content": "Who was Aristotle?"},
  121. rag_generation_config={"stream": False, "max_tokens_to_sample": 100},
  122. )
  123. # Test with specific model (if available in your setup)
  124. specific_model_response = client.retrieval.agent(
  125. message={"role": "user", "content": "Who was Aristotle?"},
  126. rag_generation_config={"stream": False, "max_tokens_to_sample": 100, "model": "openai/gpt-4.1"},
  127. )
  128. assert default_response.results.messages[-1].content, "Default model should provide response"
  129. assert specific_model_response.results.messages[-1].content, "Specific model should provide response"
  130. def test_agent_response_timing(client, test_collection):
  131. """Test agent response time is within acceptable limits."""
  132. import time
  133. start_time = time.time()
  134. response = client.retrieval.agent(
  135. message={"role": "user", "content": "Who was Aristotle?"},
  136. rag_generation_config={"stream": False, "max_tokens_to_sample": 100},
  137. )
  138. end_time = time.time()
  139. response_time = end_time - start_time
  140. assert response_time < 10, f"Agent response should complete within 10 seconds, took {response_time:.2f}s"
  141. def test_agent_handles_large_context(client):
  142. """Test agent handles large amount of context efficiently."""
  143. # Create a document with substantial content
  144. large_content = "Philosophy " * 2000 # ~16K chars
  145. doc_id = client.documents.create(raw_text=large_content).results.document_id
  146. start_time = time.time()
  147. response = client.retrieval.agent(
  148. message={"role": "user", "content": "Summarize everything you know about philosophy."},
  149. search_settings={"filters": {"document_id": {"$eq": str(doc_id)}}},
  150. rag_generation_config={"stream": False, "max_tokens_to_sample": 200},
  151. )
  152. end_time = time.time()
  153. response_time = end_time - start_time
  154. assert response.results.messages[-1].content, "Agent should produce a summary with large context"
  155. assert response_time < 20, f"Large context processing should complete in reasonable time, took {response_time:.2f}s"
  156. # Clean up
  157. client.documents.delete(id=doc_id)