123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312 |
- """
- Unit tests for the core R2RStreamingAgent functionality.
- These tests focus on the core functionality of the agent, separate from
- citation-specific behavior which is tested in test_agent_citations.py.
- """
- import pytest
- import asyncio
- import json
- import re
- from unittest.mock import MagicMock, patch, AsyncMock
- from typing import Dict, List, Tuple, Any, AsyncGenerator
- import pytest_asyncio
- from core.base import Message, LLMChatCompletion, LLMChatCompletionChunk, GenerationConfig
- from core.utils import CitationTracker, SearchResultsCollector, SSEFormatter
- from core.agent.base import R2RStreamingAgent
- # Import mock classes from conftest
- from conftest import (
- MockDatabaseProvider,
- MockLLMProvider,
- MockR2RStreamingAgent,
- MockSearchResultsCollector,
- collect_stream_output
- )
- @pytest.mark.asyncio
- async def test_streaming_agent_functionality():
- """Test basic functionality of the streaming agent."""
- # Create mock config
- config = MagicMock()
- config.stream = True
- # Create mock providers
- llm_provider = MockLLMProvider(
- response_content="This is a test response",
- citations=[]
- )
- db_provider = MockDatabaseProvider()
- # Create mock search results collector
- search_results_collector = MockSearchResultsCollector({})
- # Create agent
- agent = MockR2RStreamingAgent(
- database_provider=db_provider,
- llm_provider=llm_provider,
- config=config,
- rag_generation_config=GenerationConfig(model="test/model")
- )
- # Set the search results collector
- agent.search_results_collector = search_results_collector
- # Test a simple query
- messages = [Message(role="user", content="Test query")]
- # Run the agent
- stream = agent.arun(messages=messages)
- output = await collect_stream_output(stream)
- # Verify response
- message_events = [line for line in output if 'event: message' in line]
- assert len(message_events) > 0, "Message event should be emitted"
- # Verify final answer
- final_answer_events = [line for line in output if 'event: agent.final_answer' in line]
- assert len(final_answer_events) > 0, "Final answer event should be emitted"
- # Verify done event
- done_events = [line for line in output if 'event: done' in line]
- assert len(done_events) > 0, "Done event should be emitted"
- @pytest.mark.asyncio
- async def test_agent_handles_multiple_messages():
- """Test agent handles conversation with multiple messages."""
- # Create mock config
- config = MagicMock()
- config.stream = True
- # Create mock providers
- llm_provider = MockLLMProvider(
- response_content="This is a response to multiple messages",
- citations=[]
- )
- db_provider = MockDatabaseProvider()
- # Create mock search results collector
- search_results = {
- "abc1234": {
- "document_id": "doc_abc1234",
- "text": "This is document text for abc1234",
- "metadata": {"source": "source_abc1234"}
- },
- "def5678": {
- "document_id": "doc_def5678",
- "text": "This is document text for def5678",
- "metadata": {"source": "source_def5678"}
- }
- }
- search_results_collector = MockSearchResultsCollector(search_results)
- # Create agent
- agent = MockR2RStreamingAgent(
- database_provider=db_provider,
- llm_provider=llm_provider,
- config=config,
- rag_generation_config=GenerationConfig(model="test/model")
- )
- # Set the search results collector
- agent.search_results_collector = search_results_collector
- # Test with multiple messages
- messages = [
- Message(role="system", content="You are a helpful assistant"),
- Message(role="user", content="First question"),
- Message(role="assistant", content="First answer"),
- Message(role="user", content="Follow-up question")
- ]
- # Run the agent
- stream = agent.arun(messages=messages)
- output = await collect_stream_output(stream)
- # Verify response
- message_events = [line for line in output if 'event: message' in line]
- assert len(message_events) > 0, "Message event should be emitted"
- # After running, check that conversation has the new assistant response
- # Note: MockR2RStreamingAgent._setup adds a default system message
- # and then our messages are added, plus the agent's response
- assert len(agent.conversation.messages) == 6, "Conversation should have correct number of messages"
- # The last message should be the assistant's response
- assert agent.conversation.messages[-1].role == "assistant", "Last message should be from assistant"
- # We should have two system messages (default + our custom one)
- system_messages = [m for m in agent.conversation.messages if m.role == "system"]
- assert len(system_messages) == 2, "Should have two system messages"
- @pytest.mark.asyncio
- async def test_agent_event_format():
- """Test the format of events emitted by the agent."""
- # Create mock config
- config = MagicMock()
- config.stream = True
- # Create mock providers
- llm_provider = MockLLMProvider(
- response_content="This is a test of event formatting",
- citations=[]
- )
- db_provider = MockDatabaseProvider()
- # Create mock search results collector
- search_results_collector = MockSearchResultsCollector({})
- # Create agent
- agent = MockR2RStreamingAgent(
- database_provider=db_provider,
- llm_provider=llm_provider,
- config=config,
- rag_generation_config=GenerationConfig(model="test/model")
- )
- # Set the search results collector
- agent.search_results_collector = search_results_collector
- # Test a simple query
- messages = [Message(role="user", content="Test query")]
- # Run the agent
- stream = agent.arun(messages=messages)
- output = await collect_stream_output(stream)
- # Check message event format
- message_events = [line for line in output if 'event: message' in line]
- assert len(message_events) > 0, "Message event should be emitted"
- data_part = message_events[0].split('data: ')[1] if 'data: ' in message_events[0] else ""
- try:
- data = json.loads(data_part)
- assert "content" in data, "Message event should include content"
- except json.JSONDecodeError:
- assert False, "Message event data should be valid JSON"
- # Check final answer event format
- final_answer_events = [line for line in output if 'event: agent.final_answer' in line]
- assert len(final_answer_events) > 0, "Final answer event should be emitted"
- data_part = final_answer_events[0].split('data: ')[1] if 'data: ' in final_answer_events[0] else ""
- try:
- data = json.loads(data_part)
- assert "id" in data, "Final answer event should include ID"
- assert "object" in data, "Final answer event should include object type"
- assert "generated_answer" in data, "Final answer event should include generated answer"
- except json.JSONDecodeError:
- assert False, "Final answer event data should be valid JSON"
- @pytest.mark.asyncio
- async def test_final_answer_event_format():
- """Test that the final answer event has the expected format and content."""
- # Create mock config
- config = MagicMock()
- config.stream = True
- # Create mock providers
- llm_provider = MockLLMProvider(
- response_content="This is a test final answer",
- citations=[]
- )
- db_provider = MockDatabaseProvider()
- # Create mock search results collector
- search_results_collector = MockSearchResultsCollector({})
- # Create agent
- agent = MockR2RStreamingAgent(
- database_provider=db_provider,
- llm_provider=llm_provider,
- config=config,
- rag_generation_config=GenerationConfig(model="test/model")
- )
- # Set the search results collector
- agent.search_results_collector = search_results_collector
- # Test a simple query
- messages = [Message(role="user", content="Test query")]
- # Run the agent
- stream = agent.arun(messages=messages)
- output = await collect_stream_output(stream)
- # Extract and verify final answer event
- final_answer_events = [line for line in output if 'event: agent.final_answer' in line]
- assert len(final_answer_events) > 0, "Final answer event should be emitted"
- data_part = final_answer_events[0].split('data: ')[1] if 'data: ' in final_answer_events[0] else ""
- try:
- data = json.loads(data_part)
- assert data["id"] == "msg_final", "Final answer ID should be msg_final"
- assert data["object"] == "agent.final_answer", "Final answer object should be agent.final_answer"
- assert "generated_answer" in data, "Final answer should include generated_answer"
- assert "citations" in data, "Final answer should include citations field"
- except json.JSONDecodeError:
- assert False, "Final answer event data should be valid JSON"
- @pytest.mark.asyncio
- async def test_conversation_message_format():
- """Test that the conversation includes properly formatted assistant messages."""
- # Create mock config
- config = MagicMock()
- config.stream = True
- # Create mock providers
- llm_provider = MockLLMProvider(
- response_content="This is a test message",
- citations=[]
- )
- db_provider = MockDatabaseProvider()
- # Create mock search results collector
- search_results = {
- "abc1234": {
- "document_id": "doc_abc1234",
- "text": "This is document text for abc1234",
- "metadata": {"source": "source_abc1234"}
- },
- "def5678": {
- "document_id": "doc_def5678",
- "text": "This is document text for def5678",
- "metadata": {"source": "source_def5678"}
- }
- }
- search_results_collector = MockSearchResultsCollector(search_results)
- # Create agent
- agent = MockR2RStreamingAgent(
- database_provider=db_provider,
- llm_provider=llm_provider,
- config=config,
- rag_generation_config=GenerationConfig(model="test/model")
- )
- # Set the search results collector
- agent.search_results_collector = search_results_collector
- # Test a simple query
- messages = [Message(role="user", content="Test query")]
- # Run the agent
- stream = agent.arun(messages=messages)
- await collect_stream_output(stream)
- # Get the last message from the conversation
- last_message = agent.conversation.messages[-1]
- # Verify message format - note that MockR2RStreamingAgent uses a hardcoded response
- assert last_message.role == "assistant", "Last message should be from assistant"
- assert "This is a test response with citations" in last_message.content, "Message content should include response"
- assert "metadata" in last_message.dict(), "Message should include metadata"
- assert "citations" in last_message.metadata, "Message metadata should include citations"
|