test_agent.py 11 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312
  1. """
  2. Unit tests for the core R2RStreamingAgent functionality.
  3. These tests focus on the core functionality of the agent, separate from
  4. citation-specific behavior which is tested in test_agent_citations.py.
  5. """
  6. import pytest
  7. import asyncio
  8. import json
  9. import re
  10. from unittest.mock import MagicMock, patch, AsyncMock
  11. from typing import Dict, List, Tuple, Any, AsyncGenerator
  12. import pytest_asyncio
  13. from core.base import Message, LLMChatCompletion, LLMChatCompletionChunk, GenerationConfig
  14. from core.utils import CitationTracker, SearchResultsCollector, SSEFormatter
  15. from core.agent.base import R2RStreamingAgent
  16. # Import mock classes from conftest
  17. from conftest import (
  18. MockDatabaseProvider,
  19. MockLLMProvider,
  20. MockR2RStreamingAgent,
  21. MockSearchResultsCollector,
  22. collect_stream_output
  23. )
  24. @pytest.mark.asyncio
  25. async def test_streaming_agent_functionality():
  26. """Test basic functionality of the streaming agent."""
  27. # Create mock config
  28. config = MagicMock()
  29. config.stream = True
  30. # Create mock providers
  31. llm_provider = MockLLMProvider(
  32. response_content="This is a test response",
  33. citations=[]
  34. )
  35. db_provider = MockDatabaseProvider()
  36. # Create mock search results collector
  37. search_results_collector = MockSearchResultsCollector({})
  38. # Create agent
  39. agent = MockR2RStreamingAgent(
  40. database_provider=db_provider,
  41. llm_provider=llm_provider,
  42. config=config,
  43. rag_generation_config=GenerationConfig(model="test/model")
  44. )
  45. # Set the search results collector
  46. agent.search_results_collector = search_results_collector
  47. # Test a simple query
  48. messages = [Message(role="user", content="Test query")]
  49. # Run the agent
  50. stream = agent.arun(messages=messages)
  51. output = await collect_stream_output(stream)
  52. # Verify response
  53. message_events = [line for line in output if 'event: message' in line]
  54. assert len(message_events) > 0, "Message event should be emitted"
  55. # Verify final answer
  56. final_answer_events = [line for line in output if 'event: agent.final_answer' in line]
  57. assert len(final_answer_events) > 0, "Final answer event should be emitted"
  58. # Verify done event
  59. done_events = [line for line in output if 'event: done' in line]
  60. assert len(done_events) > 0, "Done event should be emitted"
  61. @pytest.mark.asyncio
  62. async def test_agent_handles_multiple_messages():
  63. """Test agent handles conversation with multiple messages."""
  64. # Create mock config
  65. config = MagicMock()
  66. config.stream = True
  67. # Create mock providers
  68. llm_provider = MockLLMProvider(
  69. response_content="This is a response to multiple messages",
  70. citations=[]
  71. )
  72. db_provider = MockDatabaseProvider()
  73. # Create mock search results collector
  74. search_results = {
  75. "abc1234": {
  76. "document_id": "doc_abc1234",
  77. "text": "This is document text for abc1234",
  78. "metadata": {"source": "source_abc1234"}
  79. },
  80. "def5678": {
  81. "document_id": "doc_def5678",
  82. "text": "This is document text for def5678",
  83. "metadata": {"source": "source_def5678"}
  84. }
  85. }
  86. search_results_collector = MockSearchResultsCollector(search_results)
  87. # Create agent
  88. agent = MockR2RStreamingAgent(
  89. database_provider=db_provider,
  90. llm_provider=llm_provider,
  91. config=config,
  92. rag_generation_config=GenerationConfig(model="test/model")
  93. )
  94. # Set the search results collector
  95. agent.search_results_collector = search_results_collector
  96. # Test with multiple messages
  97. messages = [
  98. Message(role="system", content="You are a helpful assistant"),
  99. Message(role="user", content="First question"),
  100. Message(role="assistant", content="First answer"),
  101. Message(role="user", content="Follow-up question")
  102. ]
  103. # Run the agent
  104. stream = agent.arun(messages=messages)
  105. output = await collect_stream_output(stream)
  106. # Verify response
  107. message_events = [line for line in output if 'event: message' in line]
  108. assert len(message_events) > 0, "Message event should be emitted"
  109. # After running, check that conversation has the new assistant response
  110. # Note: MockR2RStreamingAgent._setup adds a default system message
  111. # and then our messages are added, plus the agent's response
  112. assert len(agent.conversation.messages) == 6, "Conversation should have correct number of messages"
  113. # The last message should be the assistant's response
  114. assert agent.conversation.messages[-1].role == "assistant", "Last message should be from assistant"
  115. # We should have two system messages (default + our custom one)
  116. system_messages = [m for m in agent.conversation.messages if m.role == "system"]
  117. assert len(system_messages) == 2, "Should have two system messages"
  118. @pytest.mark.asyncio
  119. async def test_agent_event_format():
  120. """Test the format of events emitted by the agent."""
  121. # Create mock config
  122. config = MagicMock()
  123. config.stream = True
  124. # Create mock providers
  125. llm_provider = MockLLMProvider(
  126. response_content="This is a test of event formatting",
  127. citations=[]
  128. )
  129. db_provider = MockDatabaseProvider()
  130. # Create mock search results collector
  131. search_results_collector = MockSearchResultsCollector({})
  132. # Create agent
  133. agent = MockR2RStreamingAgent(
  134. database_provider=db_provider,
  135. llm_provider=llm_provider,
  136. config=config,
  137. rag_generation_config=GenerationConfig(model="test/model")
  138. )
  139. # Set the search results collector
  140. agent.search_results_collector = search_results_collector
  141. # Test a simple query
  142. messages = [Message(role="user", content="Test query")]
  143. # Run the agent
  144. stream = agent.arun(messages=messages)
  145. output = await collect_stream_output(stream)
  146. # Check message event format
  147. message_events = [line for line in output if 'event: message' in line]
  148. assert len(message_events) > 0, "Message event should be emitted"
  149. data_part = message_events[0].split('data: ')[1] if 'data: ' in message_events[0] else ""
  150. try:
  151. data = json.loads(data_part)
  152. assert "content" in data, "Message event should include content"
  153. except json.JSONDecodeError:
  154. assert False, "Message event data should be valid JSON"
  155. # Check final answer event format
  156. final_answer_events = [line for line in output if 'event: agent.final_answer' in line]
  157. assert len(final_answer_events) > 0, "Final answer event should be emitted"
  158. data_part = final_answer_events[0].split('data: ')[1] if 'data: ' in final_answer_events[0] else ""
  159. try:
  160. data = json.loads(data_part)
  161. assert "id" in data, "Final answer event should include ID"
  162. assert "object" in data, "Final answer event should include object type"
  163. assert "generated_answer" in data, "Final answer event should include generated answer"
  164. except json.JSONDecodeError:
  165. assert False, "Final answer event data should be valid JSON"
  166. @pytest.mark.asyncio
  167. async def test_final_answer_event_format():
  168. """Test that the final answer event has the expected format and content."""
  169. # Create mock config
  170. config = MagicMock()
  171. config.stream = True
  172. # Create mock providers
  173. llm_provider = MockLLMProvider(
  174. response_content="This is a test final answer",
  175. citations=[]
  176. )
  177. db_provider = MockDatabaseProvider()
  178. # Create mock search results collector
  179. search_results_collector = MockSearchResultsCollector({})
  180. # Create agent
  181. agent = MockR2RStreamingAgent(
  182. database_provider=db_provider,
  183. llm_provider=llm_provider,
  184. config=config,
  185. rag_generation_config=GenerationConfig(model="test/model")
  186. )
  187. # Set the search results collector
  188. agent.search_results_collector = search_results_collector
  189. # Test a simple query
  190. messages = [Message(role="user", content="Test query")]
  191. # Run the agent
  192. stream = agent.arun(messages=messages)
  193. output = await collect_stream_output(stream)
  194. # Extract and verify final answer event
  195. final_answer_events = [line for line in output if 'event: agent.final_answer' in line]
  196. assert len(final_answer_events) > 0, "Final answer event should be emitted"
  197. data_part = final_answer_events[0].split('data: ')[1] if 'data: ' in final_answer_events[0] else ""
  198. try:
  199. data = json.loads(data_part)
  200. assert data["id"] == "msg_final", "Final answer ID should be msg_final"
  201. assert data["object"] == "agent.final_answer", "Final answer object should be agent.final_answer"
  202. assert "generated_answer" in data, "Final answer should include generated_answer"
  203. assert "citations" in data, "Final answer should include citations field"
  204. except json.JSONDecodeError:
  205. assert False, "Final answer event data should be valid JSON"
  206. @pytest.mark.asyncio
  207. async def test_conversation_message_format():
  208. """Test that the conversation includes properly formatted assistant messages."""
  209. # Create mock config
  210. config = MagicMock()
  211. config.stream = True
  212. # Create mock providers
  213. llm_provider = MockLLMProvider(
  214. response_content="This is a test message",
  215. citations=[]
  216. )
  217. db_provider = MockDatabaseProvider()
  218. # Create mock search results collector
  219. search_results = {
  220. "abc1234": {
  221. "document_id": "doc_abc1234",
  222. "text": "This is document text for abc1234",
  223. "metadata": {"source": "source_abc1234"}
  224. },
  225. "def5678": {
  226. "document_id": "doc_def5678",
  227. "text": "This is document text for def5678",
  228. "metadata": {"source": "source_def5678"}
  229. }
  230. }
  231. search_results_collector = MockSearchResultsCollector(search_results)
  232. # Create agent
  233. agent = MockR2RStreamingAgent(
  234. database_provider=db_provider,
  235. llm_provider=llm_provider,
  236. config=config,
  237. rag_generation_config=GenerationConfig(model="test/model")
  238. )
  239. # Set the search results collector
  240. agent.search_results_collector = search_results_collector
  241. # Test a simple query
  242. messages = [Message(role="user", content="Test query")]
  243. # Run the agent
  244. stream = agent.arun(messages=messages)
  245. await collect_stream_output(stream)
  246. # Get the last message from the conversation
  247. last_message = agent.conversation.messages[-1]
  248. # Verify message format - note that MockR2RStreamingAgent uses a hardcoded response
  249. assert last_message.role == "assistant", "Last message should be from assistant"
  250. assert "This is a test response with citations" in last_message.content, "Message content should include response"
  251. assert "metadata" in last_message.dict(), "Message should include metadata"
  252. assert "citations" in last_message.metadata, "Message metadata should include citations"