test_agent_citations_old.py 44 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702703704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880881882883884885886887888889890891892893894895896897898899900901902903904905906907908909910911912913914915916917918919920921922923924925926927928929930931932933934935936937938939940941942943944945946947948949950951952953954955956957958959960961962963964965966967968969970971972973974975976977978979980981982983984985986987988989990991992993994995996997998999100010011002100310041005100610071008100910101011101210131014101510161017101810191020102110221023102410251026102710281029103010311032103310341035103610371038103910401041104210431044104510461047104810491050105110521053105410551056105710581059106010611062106310641065106610671068106910701071107210731074107510761077107810791080108110821083108410851086108710881089109010911092109310941095109610971098109911001101110211031104110511061107110811091110111111121113111411151116111711181119112011211122112311241125112611271128112911301131113211331134113511361137
  1. """
  2. Unit tests for citation extraction and propagation in the R2RStreamingAgent.
  3. These tests focus specifically on citation-related functionality:
  4. - Citation extraction from text
  5. - Citation tracking during streaming
  6. - Citation event emission
  7. - Citation formatting and propagation
  8. - Citation edge cases and validation
  9. """
  10. import pytest
  11. import asyncio
  12. import json
  13. import re
  14. from unittest.mock import MagicMock, patch, AsyncMock
  15. from typing import Dict, List, Tuple, Any, AsyncGenerator
  16. import pytest_asyncio
  17. from core.base import Message, LLMChatCompletion, LLMChatCompletionChunk, GenerationConfig
  18. from core.utils import CitationTracker, extract_citations, extract_citation_spans
  19. from core.agent.base import R2RStreamingAgent
  20. # Import mock classes from conftest
  21. from conftest import (
  22. MockDatabaseProvider,
  23. MockLLMProvider,
  24. MockR2RStreamingAgent,
  25. MockSearchResultsCollector,
  26. collect_stream_output
  27. )
  28. class MockLLMProvider:
  29. """Mock LLM provider for testing."""
  30. def __init__(self, response_content=None, citations=None):
  31. self.response_content = response_content or "This is a response"
  32. self.citations = citations or []
  33. async def aget_completion(self, messages, generation_config):
  34. """Mock synchronous completion."""
  35. content = self.response_content
  36. for citation in self.citations:
  37. content += f" [{citation}]"
  38. mock_response = MagicMock(spec=LLMChatCompletion)
  39. mock_response.choices = [MagicMock()]
  40. mock_response.choices[0].message = MagicMock()
  41. mock_response.choices[0].message.content = content
  42. mock_response.choices[0].finish_reason = "stop"
  43. return mock_response
  44. async def aget_completion_stream(self, messages, generation_config):
  45. """Mock streaming completion."""
  46. content = self.response_content
  47. for citation in self.citations:
  48. content += f" [{citation}]"
  49. # Simulate streaming by yielding one character at a time
  50. for i in range(len(content)):
  51. chunk = MagicMock(spec=LLMChatCompletionChunk)
  52. chunk.choices = [MagicMock()]
  53. chunk.choices[0].delta = MagicMock()
  54. chunk.choices[0].delta.content = content[i]
  55. chunk.choices[0].finish_reason = None
  56. yield chunk
  57. # Final chunk with finish_reason="stop"
  58. final_chunk = MagicMock(spec=LLMChatCompletionChunk)
  59. final_chunk.choices = [MagicMock()]
  60. final_chunk.choices[0].delta = MagicMock()
  61. final_chunk.choices[0].delta.content = ""
  62. final_chunk.choices[0].finish_reason = "stop"
  63. yield final_chunk
  64. class MockPromptsHandler:
  65. """Mock prompts handler for testing."""
  66. async def get_cached_prompt(self, prompt_key, inputs=None, *args, **kwargs):
  67. """Return a mock system prompt."""
  68. return "You are a helpful assistant that provides well-sourced information."
  69. class MockDatabaseProvider:
  70. """Mock database provider for testing."""
  71. def __init__(self):
  72. # Add a prompts_handler attribute to prevent AttributeError
  73. self.prompts_handler = MockPromptsHandler()
  74. async def acreate_conversation(self, *args, **kwargs):
  75. return {"id": "conv_12345"}
  76. async def aupdate_conversation(self, *args, **kwargs):
  77. return True
  78. async def acreate_message(self, *args, **kwargs):
  79. return {"id": "msg_12345"}
  80. class MockSearchResultsCollector:
  81. """Mock search results collector for testing."""
  82. def __init__(self, results=None):
  83. self.results = results or {}
  84. def find_by_short_id(self, short_id):
  85. return self.results.get(short_id, {
  86. "document_id": f"doc_{short_id}",
  87. "text": f"This is document text for {short_id}",
  88. "metadata": {"source": f"source_{short_id}"}
  89. })
  90. # Create a concrete implementation of R2RStreamingAgent for testing
  91. class MockR2RStreamingAgent(R2RStreamingAgent):
  92. """Mock streaming agent for testing that implements the abstract method."""
  93. # Regex pattern for citations, copied from the actual agent
  94. BRACKET_PATTERN = re.compile(r"\[([^\]]+)\]")
  95. SHORT_ID_PATTERN = re.compile(r"[A-Za-z0-9]{7,8}")
  96. def _register_tools(self):
  97. """Implement the abstract method with a no-op version."""
  98. pass
  99. async def _setup(self, system_instruction=None, *args, **kwargs):
  100. """Override _setup to simplify initialization and avoid external dependencies."""
  101. # Use a simple system message instead of fetching from database
  102. system_content = system_instruction or "You are a helpful assistant that provides well-sourced information."
  103. # Add system message to conversation
  104. await self.conversation.add_message(
  105. Message(role="system", content=system_content)
  106. )
  107. def _format_sse_event(self, event_type, data):
  108. """Format an SSE event manually."""
  109. return f"event: {event_type}\ndata: {json.dumps(data)}\n\n"
  110. async def arun(
  111. self,
  112. system_instruction: str = None,
  113. messages: list[Message] = None,
  114. *args,
  115. **kwargs,
  116. ) -> AsyncGenerator[str, None]:
  117. """
  118. Simplified version of arun that focuses on citation handling for testing.
  119. """
  120. await self._setup(system_instruction)
  121. if messages:
  122. for m in messages:
  123. await self.conversation.add_message(m)
  124. # Initialize citation tracker
  125. citation_tracker = CitationTracker()
  126. citation_payloads = {}
  127. # Track streaming citations for final persistence
  128. self.streaming_citations = []
  129. # Get the LLM response with citations
  130. response_content = "This is a test response with citations"
  131. response_content += " [abc1234] [def5678]"
  132. # Yield an initial message event with the start of the text
  133. yield self._format_sse_event("message", {"content": response_content})
  134. # Manually extract and emit citation events
  135. # This is a simpler approach than the character-by-character approach
  136. citation_spans = extract_citation_spans(response_content)
  137. # Process the citations
  138. for cid, spans in citation_spans.items():
  139. for span in spans:
  140. # Check if the span is new and record it
  141. if citation_tracker.is_new_span(cid, span):
  142. # Look up the source document for this citation
  143. source_doc = self.search_results_collector.find_by_short_id(cid)
  144. # Create citation payload
  145. citation_payload = {
  146. "document_id": source_doc.get("document_id", f"doc_{cid}"),
  147. "text": source_doc.get("text", f"This is document text for {cid}"),
  148. "metadata": source_doc.get("metadata", {"source": f"source_{cid}"}),
  149. }
  150. # Store the payload by citation ID
  151. citation_payloads[cid] = citation_payload
  152. # Track for persistence
  153. self.streaming_citations.append({
  154. "id": cid,
  155. "span": {"start": span[0], "end": span[1]},
  156. "payload": citation_payload
  157. })
  158. # Emit citation event
  159. citation_event = {
  160. "id": cid,
  161. "object": "citation",
  162. "span": {"start": span[0], "end": span[1]},
  163. "payload": citation_payload
  164. }
  165. yield self._format_sse_event("citation", citation_event)
  166. # Add assistant message with citation metadata to conversation
  167. await self.conversation.add_message(
  168. Message(
  169. role="assistant",
  170. content=response_content,
  171. metadata={"citations": self.streaming_citations}
  172. )
  173. )
  174. # Prepare consolidated citations for final answer
  175. consolidated_citations = []
  176. # Group citations by ID with all their spans
  177. for cid, spans in citation_tracker.get_all_spans().items():
  178. if cid in citation_payloads:
  179. consolidated_citations.append({
  180. "id": cid,
  181. "object": "citation",
  182. "spans": [{"start": s[0], "end": s[1]} for s in spans],
  183. "payload": citation_payloads[cid]
  184. })
  185. # Create and emit final answer event
  186. final_evt_payload = {
  187. "id": "msg_final",
  188. "object": "agent.final_answer",
  189. "generated_answer": response_content,
  190. "citations": consolidated_citations
  191. }
  192. # Manually format the final answer event
  193. yield self._format_sse_event("agent.final_answer", final_evt_payload)
  194. # Signal the end of the SSE stream
  195. yield "event: done\ndata: {}\n\n"
  196. @pytest.fixture
  197. def mock_streaming_agent():
  198. """Create a streaming agent with mocked dependencies."""
  199. # Create mock config
  200. config = MagicMock()
  201. config.stream = True
  202. config.max_iterations = 3
  203. # Create mock providers
  204. llm_provider = MockLLMProvider(
  205. response_content="This is a test response with citations",
  206. citations=["abc1234", "def5678"]
  207. )
  208. db_provider = MockDatabaseProvider()
  209. # Create agent with mocked dependencies using our concrete implementation
  210. agent = MockR2RStreamingAgent(
  211. database_provider=db_provider,
  212. llm_provider=llm_provider,
  213. config=config,
  214. rag_generation_config=GenerationConfig(model="test/model")
  215. )
  216. # Replace the search results collector with our mock
  217. agent.search_results_collector = MockSearchResultsCollector({
  218. "abc1234": {
  219. "document_id": "doc_abc1234",
  220. "text": "This is document text for abc1234",
  221. "metadata": {"source": "source_abc1234"}
  222. },
  223. "def5678": {
  224. "document_id": "doc_def5678",
  225. "text": "This is document text for def5678",
  226. "metadata": {"source": "source_def5678"}
  227. }
  228. })
  229. return agent
  230. async def collect_stream_output(stream):
  231. """Collect all output from a stream into a list."""
  232. output = []
  233. async for event in stream:
  234. output.append(event)
  235. return output
  236. def test_extract_citations_from_response():
  237. """Test that citations are extracted from LLM responses."""
  238. response_text = "This is a response with a citation [abc1234]."
  239. # Use the utility function directly
  240. citations = extract_citations(response_text)
  241. assert "abc1234" in citations, "Citation should be extracted from response"
  242. @pytest.mark.asyncio
  243. async def test_streaming_agent_citation_extraction(mock_streaming_agent):
  244. """Test that streaming agent extracts citations from streamed content."""
  245. # Run the agent
  246. messages = [Message(role="user", content="Test query")]
  247. # We need to run this in a coroutine
  248. stream = mock_streaming_agent.arun(messages=messages)
  249. output = await collect_stream_output(stream)
  250. # Look for citation events in the output
  251. citation_events = [
  252. line for line in output
  253. if 'event: citation' in line
  254. ]
  255. assert len(citation_events) > 0, "Citation events should be emitted"
  256. # Check citation IDs in events
  257. citation_abc = any('abc1234' in event for event in citation_events)
  258. citation_def = any('def5678' in event for event in citation_events)
  259. assert citation_abc, "Citation abc1234 should be found in stream output"
  260. assert citation_def, "Citation def5678 should be found in stream output"
  261. @pytest.mark.asyncio
  262. async def test_citation_tracker_during_streaming(mock_streaming_agent):
  263. """Test that CitationTracker correctly tracks processed citations during streaming."""
  264. # We need to patch the is_new_span method to verify it's being used correctly
  265. # Use autospec=True to ensure the method signature is preserved
  266. with patch('core.utils.CitationTracker.is_new_span', autospec=True) as mock_is_new_span:
  267. # Configure the mock to return True so citations will be processed
  268. mock_is_new_span.return_value = True
  269. messages = [Message(role="user", content="Test query")]
  270. # Run the agent
  271. stream = mock_streaming_agent.arun(messages=messages)
  272. output = await collect_stream_output(stream)
  273. # Verify that CitationTracker.is_new_span method was called
  274. assert mock_is_new_span.call_count > 0, "is_new_span should be called to track citation spans"
  275. @pytest.mark.asyncio
  276. async def test_final_answer_includes_consolidated_citations(mock_streaming_agent):
  277. """Test that the final answer includes consolidated citations."""
  278. messages = [Message(role="user", content="Test query")]
  279. # Run the agent
  280. stream = mock_streaming_agent.arun(messages=messages)
  281. output = await collect_stream_output(stream)
  282. # Look for final answer event in the output
  283. final_answer_events = [
  284. line for line in output
  285. if 'event: agent.final_answer' in line
  286. ]
  287. assert len(final_answer_events) > 0, "Final answer event should be emitted"
  288. # Parse the event to check for citations
  289. for event in final_answer_events:
  290. data_part = event.split('data: ')[1] if 'data: ' in event else event
  291. try:
  292. data = json.loads(data_part)
  293. if 'citations' in data:
  294. assert len(data['citations']) > 0, "Final answer should include citations"
  295. citation_ids = [citation.get('id') for citation in data['citations']]
  296. assert 'abc1234' in citation_ids or 'def5678' in citation_ids, "Known citation IDs should be included"
  297. except json.JSONDecodeError:
  298. continue
  299. @pytest.mark.asyncio
  300. async def test_conversation_message_includes_citation_metadata(mock_streaming_agent):
  301. """Test that conversation messages include citation metadata."""
  302. with patch.object(mock_streaming_agent.conversation, 'add_message', wraps=mock_streaming_agent.conversation.add_message) as mock_add_message:
  303. messages = [Message(role="user", content="Test query")]
  304. # Run the agent
  305. stream = mock_streaming_agent.arun(messages=messages)
  306. output = await collect_stream_output(stream)
  307. # Check that add_message was called with citation metadata
  308. citation_calls = 0
  309. for call in mock_add_message.call_args_list:
  310. args, kwargs = call
  311. if args and isinstance(args[0], Message):
  312. message = args[0]
  313. if message.role == 'assistant' and message.metadata and 'citations' in message.metadata:
  314. citation_calls += 1
  315. assert citation_calls > 0, "At least one assistant message should include citation metadata"
  316. @pytest.mark.asyncio
  317. async def test_multiple_citations_for_same_source(mock_streaming_agent):
  318. """Test handling of multiple citations for the same source document."""
  319. # Create a custom citation tracker that we can control
  320. citation_tracker = CitationTracker()
  321. # Create a custom MockR2RStreamingAgent with our controlled citation tracker
  322. with patch('core.utils.CitationTracker', return_value=citation_tracker):
  323. custom_agent = mock_streaming_agent
  324. # Modify the arun method to include repeated citations for the same source
  325. original_arun = custom_agent.arun
  326. async def custom_arun(*args, **kwargs):
  327. """Custom arun that includes repeated citations for the same source."""
  328. # Setup like the original
  329. await custom_agent._setup(kwargs.get('system_instruction'))
  330. messages = kwargs.get('messages', [])
  331. if messages:
  332. for m in messages:
  333. await custom_agent.conversation.add_message(m)
  334. # Initialize payloads dict for tracking
  335. citation_payloads = {}
  336. # Track streaming citations for final persistence
  337. custom_agent.streaming_citations = []
  338. # Create text with multiple citations to the same source
  339. response_content = "This text has multiple citations to the same source: [abc1234] and again here [abc1234]."
  340. # Yield the message event
  341. yield custom_agent._format_sse_event("message", {"content": response_content})
  342. # Manually extract and emit citation events
  343. # This is a simpler approach than the character-by-character approach
  344. citation_spans = extract_citation_spans(response_content)
  345. # Process the citations
  346. for cid, spans in citation_spans.items():
  347. for span in spans:
  348. # Mark as processed in the tracker
  349. citation_tracker.is_new_span(cid, span)
  350. # Look up the source document for this citation
  351. source_doc = custom_agent.search_results_collector.find_by_short_id(cid)
  352. # Create citation payload
  353. citation_payload = {
  354. "document_id": source_doc.get("document_id", f"doc_{cid}"),
  355. "text": source_doc.get("text", f"This is document text for {cid}"),
  356. "metadata": source_doc.get("metadata", {"source": f"source_{cid}"}),
  357. }
  358. # Store the payload
  359. citation_payloads[cid] = citation_payload
  360. # Track for persistence
  361. custom_agent.streaming_citations.append({
  362. "id": cid,
  363. "span": {"start": span[0], "end": span[1]},
  364. "payload": citation_payload
  365. })
  366. # Emit citation event
  367. citation_event = {
  368. "id": cid,
  369. "object": "citation",
  370. "span": {"start": span[0], "end": span[1]},
  371. "payload": citation_payload
  372. }
  373. yield custom_agent._format_sse_event("citation", citation_event)
  374. # Add assistant message with citation metadata to conversation
  375. await custom_agent.conversation.add_message(
  376. Message(
  377. role="assistant",
  378. content=response_content,
  379. metadata={"citations": custom_agent.streaming_citations}
  380. )
  381. )
  382. # Prepare consolidated citations for final answer
  383. consolidated_citations = []
  384. # Group citations by ID with all their spans
  385. for cid, spans in citation_tracker.get_all_spans().items():
  386. if cid in citation_payloads:
  387. consolidated_citations.append({
  388. "id": cid,
  389. "object": "citation",
  390. "spans": [{"start": s[0], "end": s[1]} for s in spans],
  391. "payload": citation_payloads[cid]
  392. })
  393. # Create and emit final answer event
  394. final_evt_payload = {
  395. "id": "msg_final",
  396. "object": "agent.final_answer",
  397. "generated_answer": response_content,
  398. "citations": consolidated_citations
  399. }
  400. yield custom_agent._format_sse_event("agent.final_answer", final_evt_payload)
  401. # Signal the end of the SSE stream
  402. yield "event: done\ndata: {}\n\n"
  403. # Apply the custom arun method
  404. with patch.object(custom_agent, 'arun', custom_arun):
  405. messages = [Message(role="user", content="Test query")]
  406. # Run the agent with overlapping citations
  407. stream = custom_agent.arun(messages=messages)
  408. output = await collect_stream_output(stream)
  409. # Count citation events for abc1234
  410. citation_abc_events = [
  411. line for line in output
  412. if 'event: citation' in line and 'abc1234' in line
  413. ]
  414. # There should be at least 2 citations for abc1234 (the original and our added one)
  415. assert len(citation_abc_events) >= 2, "Should emit multiple citation events for the same source"
  416. # Check the final answer to ensure spans were consolidated
  417. final_answer_events = [
  418. line for line in output
  419. if 'event: agent.final_answer' in line
  420. ]
  421. for event in final_answer_events:
  422. data_part = event.split('data: ')[1] if 'data: ' in event else event
  423. try:
  424. data = json.loads(data_part)
  425. if 'citations' in data:
  426. # Find the citation for abc1234
  427. abc_citation = next((citation for citation in data['citations'] if citation.get('id') == 'abc1234'), None)
  428. if abc_citation:
  429. # It should have multiple spans
  430. assert abc_citation.get('spans') and len(abc_citation['spans']) >= 2, "Citation should have multiple spans consolidated"
  431. except json.JSONDecodeError:
  432. continue
  433. @pytest.mark.asyncio
  434. async def test_citation_consolidation_logic(mock_streaming_agent):
  435. """Test that citation consolidation properly groups spans by citation ID."""
  436. # Patch the get_all_spans method to return a controlled set of spans
  437. citation_tracker = CitationTracker()
  438. # Add spans for multiple citations
  439. citation_tracker.is_new_span("abc1234", (10, 20))
  440. citation_tracker.is_new_span("abc1234", (30, 40))
  441. citation_tracker.is_new_span("def5678", (50, 60))
  442. citation_tracker.is_new_span("ghi9012", (70, 80))
  443. citation_tracker.is_new_span("ghi9012", (90, 100))
  444. # Create a custom mock agent that uses our pre-populated citation tracker
  445. with patch('core.utils.CitationTracker', return_value=citation_tracker):
  446. # Create a fresh agent with our mocked citation tracker
  447. new_agent = mock_streaming_agent
  448. messages = [Message(role="user", content="Test query")]
  449. # Run the agent
  450. stream = new_agent.arun(messages=messages)
  451. output = await collect_stream_output(stream)
  452. # Look for the final answer event
  453. final_answer_events = [
  454. line for line in output
  455. if 'event: agent.final_answer' in line
  456. ]
  457. # Verify consolidation in final answer
  458. for event in final_answer_events:
  459. data_part = event.split('data: ')[1] if 'data: ' in event else event
  460. try:
  461. data = json.loads(data_part)
  462. if 'citations' in data:
  463. # There should be at least 2 citations (from our mock agent implementation)
  464. assert len(data['citations']) >= 2, "Should include multiple citation objects"
  465. # Check spans for each citation
  466. for citation in data['citations']:
  467. cid = citation.get('id')
  468. if cid == 'abc1234':
  469. # Spans should be consolidated for abc1234
  470. spans = citation.get('spans', [])
  471. assert len(spans) >= 1, f"Citation {cid} should have spans"
  472. except json.JSONDecodeError:
  473. continue
  474. @pytest.mark.asyncio
  475. async def test_citation_event_format(mock_streaming_agent):
  476. """Test that citation events follow the expected format."""
  477. messages = [Message(role="user", content="Test query")]
  478. # Run the agent
  479. stream = mock_streaming_agent.arun(messages=messages)
  480. output = await collect_stream_output(stream)
  481. # Extract citation events
  482. citation_events = [
  483. line for line in output
  484. if 'event: citation' in line
  485. ]
  486. assert len(citation_events) > 0, "Citation events should be emitted"
  487. # Check the format of each citation event
  488. for event in citation_events:
  489. # Should have 'event: citation' and 'data: {...}'
  490. assert 'event: citation' in event, "Event type should be 'citation'"
  491. assert 'data: ' in event, "Event should have data payload"
  492. # Parse the data payload
  493. data_part = event.split('data: ')[1] if 'data: ' in event else event
  494. try:
  495. data = json.loads(data_part)
  496. # Check required fields
  497. assert 'id' in data, "Citation event should have an 'id'"
  498. assert 'object' in data and data['object'] == 'citation', "Event object should be 'citation'"
  499. assert 'span' in data, "Citation event should have a 'span'"
  500. assert 'start' in data['span'] and 'end' in data['span'], "Span should have 'start' and 'end'"
  501. assert 'payload' in data, "Citation event should have a 'payload'"
  502. # Check payload fields
  503. assert 'document_id' in data['payload'], "Payload should have 'document_id'"
  504. assert 'text' in data['payload'], "Payload should have 'text'"
  505. assert 'metadata' in data['payload'], "Payload should have 'metadata'"
  506. except json.JSONDecodeError:
  507. pytest.fail(f"Citation event data is not valid JSON: {data_part}")
  508. @pytest.mark.asyncio
  509. async def test_final_answer_event_format(mock_streaming_agent):
  510. """Test that the final answer event follows the expected format."""
  511. messages = [Message(role="user", content="Test query")]
  512. # Run the agent
  513. stream = mock_streaming_agent.arun(messages=messages)
  514. output = await collect_stream_output(stream)
  515. # Look for final answer event
  516. final_answer_events = [
  517. line for line in output
  518. if 'event: agent.final_answer' in line
  519. ]
  520. assert len(final_answer_events) > 0, "Final answer event should be emitted"
  521. # Check the format of the final answer event
  522. for event in final_answer_events:
  523. assert 'event: agent.final_answer' in event, "Event type should be 'agent.final_answer'"
  524. assert 'data: ' in event, "Event should have data payload"
  525. # Parse the data payload
  526. data_part = event.split('data: ')[1] if 'data: ' in event else event
  527. try:
  528. data = json.loads(data_part)
  529. # Check required fields
  530. assert 'id' in data, "Final answer event should have an 'id'"
  531. assert 'object' in data and data['object'] == 'agent.final_answer', "Event object should be 'agent.final_answer'"
  532. assert 'generated_answer' in data, "Final answer event should have a 'generated_answer'"
  533. assert 'citations' in data, "Final answer event should have 'citations'"
  534. # Check citation fields
  535. for citation in data['citations']:
  536. assert 'id' in citation, "Citation should have an 'id'"
  537. assert 'object' in citation and citation['object'] == 'citation', "Citation object should be 'citation'"
  538. assert 'spans' in citation, "Citation should have 'spans'"
  539. assert 'payload' in citation, "Citation should have a 'payload'"
  540. # Check spans format
  541. for span in citation['spans']:
  542. assert 'start' in span, "Span should have 'start'"
  543. assert 'end' in span, "Span should have 'end'"
  544. # Check payload fields
  545. assert 'document_id' in citation['payload'], "Payload should have 'document_id'"
  546. assert 'text' in citation['payload'], "Payload should have 'text'"
  547. assert 'metadata' in citation['payload'], "Payload should have 'metadata'"
  548. except json.JSONDecodeError:
  549. pytest.fail(f"Final answer event data is not valid JSON: {data_part}")
  550. @pytest.mark.asyncio
  551. async def test_overlapping_citation_handling():
  552. """Test that overlapping citations are handled correctly."""
  553. # Create a custom agent configuration
  554. config = MagicMock()
  555. config.stream = True
  556. config.max_iterations = 3
  557. # Create providers
  558. llm_provider = MockLLMProvider(
  559. response_content="This is a test response with overlapping citations",
  560. citations=["abc1234", "def5678"]
  561. )
  562. db_provider = MockDatabaseProvider()
  563. # Create agent
  564. agent = MockR2RStreamingAgent(
  565. database_provider=db_provider,
  566. llm_provider=llm_provider,
  567. config=config,
  568. rag_generation_config=GenerationConfig(model="test/model")
  569. )
  570. # Replace the search results collector with our mock
  571. agent.search_results_collector = MockSearchResultsCollector({
  572. "abc1234": {
  573. "document_id": "doc_abc1234",
  574. "text": "This is document text for abc1234",
  575. "metadata": {"source": "source_abc1234"}
  576. },
  577. "def5678": {
  578. "document_id": "doc_def5678",
  579. "text": "This is document text for def5678",
  580. "metadata": {"source": "source_def5678"}
  581. }
  582. })
  583. # Modify the arun method for overlapping citations
  584. original_arun = agent.arun
  585. async def custom_arun(*args, **kwargs):
  586. """Custom arun that includes overlapping citations."""
  587. # Setup like the original
  588. await agent._setup(kwargs.get('system_instruction'))
  589. messages = kwargs.get('messages', [])
  590. if messages:
  591. for m in messages:
  592. await agent.conversation.add_message(m)
  593. # Initialize citation tracker
  594. citation_tracker = CitationTracker()
  595. citation_payloads = {}
  596. # Track streaming citations for final persistence
  597. agent.streaming_citations = []
  598. # Create text with overlapping citations (citation spans that overlap)
  599. response_content = "This text has overlapping citations [abc1234] part of which [def5678] overlap."
  600. # Yield the message event
  601. yield agent._format_sse_event("message", {"content": response_content})
  602. # Manually create overlapping citation spans
  603. # For simplicity, we'll define the spans directly rather than using regex
  604. citation_spans = {
  605. "abc1234": [(30, 39)], # This span includes "[abc1234]"
  606. "def5678": [(55, 64)] # This span includes "[def5678]"
  607. }
  608. # Process the citations
  609. for cid, spans in citation_spans.items():
  610. for span in spans:
  611. # Mark as processed in the tracker
  612. citation_tracker.is_new_span(cid, span)
  613. # Look up the source document for this citation
  614. source_doc = agent.search_results_collector.find_by_short_id(cid)
  615. # Create citation payload
  616. citation_payload = {
  617. "document_id": source_doc.get("document_id", f"doc_{cid}"),
  618. "text": source_doc.get("text", f"This is document text for {cid}"),
  619. "metadata": source_doc.get("metadata", {"source": f"source_{cid}"}),
  620. }
  621. # Store the payload by citation ID
  622. citation_payloads[cid] = citation_payload
  623. # Track for persistence
  624. agent.streaming_citations.append({
  625. "id": cid,
  626. "span": {"start": span[0], "end": span[1]},
  627. "payload": citation_payload
  628. })
  629. # Emit citation event
  630. citation_event = {
  631. "id": cid,
  632. "object": "citation",
  633. "span": {"start": span[0], "end": span[1]},
  634. "payload": citation_payload
  635. }
  636. yield agent._format_sse_event("citation", citation_event)
  637. # Add assistant message with citation metadata to conversation
  638. await agent.conversation.add_message(
  639. Message(
  640. role="assistant",
  641. content=response_content,
  642. metadata={"citations": agent.streaming_citations}
  643. )
  644. )
  645. # Prepare consolidated citations for final answer
  646. consolidated_citations = []
  647. # Group citations by ID with all their spans
  648. for cid, spans in citation_tracker.get_all_spans().items():
  649. if cid in citation_payloads:
  650. consolidated_citations.append({
  651. "id": cid,
  652. "object": "citation",
  653. "spans": [{"start": s[0], "end": s[1]} for s in spans],
  654. "payload": citation_payloads[cid]
  655. })
  656. # Create and emit final answer event
  657. final_evt_payload = {
  658. "id": "msg_final",
  659. "object": "agent.final_answer",
  660. "generated_answer": response_content,
  661. "citations": consolidated_citations
  662. }
  663. # Emit final answer event
  664. yield agent._format_sse_event("agent.final_answer", final_evt_payload)
  665. # Signal the end of the SSE stream
  666. yield "event: done\ndata: {}\n\n"
  667. # Replace the arun method
  668. with patch.object(agent, 'arun', custom_arun):
  669. messages = [Message(role="user", content="Test query")]
  670. # Run the agent with overlapping citations
  671. stream = agent.arun(messages=messages)
  672. output = await collect_stream_output(stream)
  673. # Check that both citations were emitted
  674. citation_abc = any('abc1234' in event for event in output if 'event: citation' in event)
  675. citation_def = any('def5678' in event for event in output if 'event: citation' in event)
  676. assert citation_abc, "Citation abc1234 should be emitted"
  677. assert citation_def, "Citation def5678 should be emitted"
  678. # Check the final answer for both citations
  679. final_answer_events = [
  680. line for line in output
  681. if 'event: agent.final_answer' in line
  682. ]
  683. for event in final_answer_events:
  684. data_part = event.split('data: ')[1] if 'data: ' in event else event
  685. try:
  686. data = json.loads(data_part)
  687. if 'citations' in data:
  688. citation_ids = [citation.get('id') for citation in data['citations']]
  689. assert 'abc1234' in citation_ids, "abc1234 should be in final answer citations"
  690. assert 'def5678' in citation_ids, "def5678 should be in final answer citations"
  691. except json.JSONDecodeError:
  692. continue
  693. @pytest.mark.asyncio
  694. async def test_robustness_against_citation_variations(mock_streaming_agent):
  695. """Test agent's robustness against different citation formats and variations."""
  696. # Create a custom text with different citation variations
  697. response_text = """
  698. This text has different citation variations:
  699. 1. Standard citation: [abc1234]
  700. 2. Another citation: [def5678]
  701. 3. Adjacent citations: [abc1234][def5678]
  702. 4. Special characters around citation: ([abc1234]) or "[def5678]".
  703. """
  704. # Use the extract_citations function directly to see what would be detected
  705. citations = extract_citations(response_text)
  706. # There should be at least two different citation IDs
  707. unique_citations = set(citations)
  708. assert len(unique_citations) >= 2, "Should extract at least two different citation IDs"
  709. assert "abc1234" in unique_citations, "Should extract abc1234"
  710. assert "def5678" in unique_citations, "Should extract def5678"
  711. # Count occurrences of each citation
  712. counts = {}
  713. for cid in citations:
  714. counts[cid] = counts.get(cid, 0) + 1
  715. # Each citation should be found the correct number of times based on the text
  716. assert counts.get("abc1234", 0) >= 2, "abc1234 should appear at least twice"
  717. assert counts.get("def5678", 0) >= 2, "def5678 should appear at least twice"
  718. class TestCitationEdgeCases:
  719. """
  720. Test class for citation edge cases using parameterized tests to cover multiple scenarios.
  721. """
  722. @pytest.mark.parametrize("test_case", [
  723. # Test case 1: Empty text
  724. {"text": "", "expected_citations": []},
  725. # Test case 2: Text with no citations
  726. {"text": "This text has no citations.", "expected_citations": []},
  727. # Test case 3: Adjacent citations
  728. {"text": "Adjacent citations [abc1234][def5678]", "expected_citations": ["abc1234", "def5678"]},
  729. # Test case 4: Repeated citations
  730. {"text": "Repeated [abc1234] citation [abc1234]", "expected_citations": ["abc1234", "abc1234"]},
  731. # Test case 5: Citation at beginning
  732. {"text": "[abc1234] at beginning", "expected_citations": ["abc1234"]},
  733. # Test case 6: Citation at end
  734. {"text": "At end [abc1234]", "expected_citations": ["abc1234"]},
  735. # Test case 7: Mixed valid and invalid citations
  736. {"text": "Valid [abc1234] and invalid [ab123] citations", "expected_citations": ["abc1234"]},
  737. # Test case 8: Citations with punctuation
  738. {"text": "Citations with punctuation: ([abc1234]), [def5678]!", "expected_citations": ["abc1234", "def5678"]}
  739. ])
  740. def test_citation_extraction_cases(self, test_case):
  741. """Test citation extraction with various edge cases."""
  742. text = test_case["text"]
  743. expected = test_case["expected_citations"]
  744. # Extract citations
  745. actual = extract_citations(text)
  746. # Check count
  747. assert len(actual) == len(expected), f"Expected {len(expected)} citations, got {len(actual)}"
  748. # Check content (allowing for different orders)
  749. if expected:
  750. for expected_citation in expected:
  751. assert expected_citation in actual, f"Expected citation {expected_citation} not found"
  752. @pytest.mark.asyncio
  753. async def test_citation_handling_with_empty_response():
  754. """Test how the agent handles responses with no citations."""
  755. # Create a custom R2RStreamingAgent with no citations
  756. # Custom agent class for testing empty citations
  757. class EmptyResponseAgent(MockR2RStreamingAgent):
  758. async def arun(
  759. self,
  760. system_instruction: str = None,
  761. messages: list[Message] = None,
  762. *args,
  763. **kwargs,
  764. ) -> AsyncGenerator[str, None]:
  765. """Custom arun with no citations in the response."""
  766. await self._setup(system_instruction)
  767. if messages:
  768. for m in messages:
  769. await self.conversation.add_message(m)
  770. # Initialize citation tracker
  771. citation_tracker = CitationTracker()
  772. # Empty response with no citations
  773. response_content = "This is a response with no citations."
  774. # Yield an initial message event with the start of the text
  775. yield self._format_sse_event("message", {"content": response_content})
  776. # No citation spans to extract
  777. citation_spans = extract_citation_spans(response_content)
  778. # Should be empty
  779. assert len(citation_spans) == 0, "No citation spans should be found"
  780. # Add assistant message to conversation (with no citation metadata)
  781. await self.conversation.add_message(
  782. Message(
  783. role="assistant",
  784. content=response_content,
  785. metadata={"citations": []}
  786. )
  787. )
  788. # Create and emit final answer event
  789. final_evt_payload = {
  790. "id": "msg_final",
  791. "object": "agent.final_answer",
  792. "generated_answer": response_content,
  793. "citations": []
  794. }
  795. yield self._format_sse_event("agent.final_answer", final_evt_payload)
  796. yield "event: done\ndata: {}\n\n"
  797. # Create the agent with empty citation response
  798. config = MagicMock()
  799. config.stream = True
  800. llm_provider = MockLLMProvider(
  801. response_content="This is a response with no citations.",
  802. citations=[]
  803. )
  804. db_provider = MockDatabaseProvider()
  805. # Create the custom agent
  806. agent = EmptyResponseAgent(
  807. database_provider=db_provider,
  808. llm_provider=llm_provider,
  809. config=config,
  810. rag_generation_config=GenerationConfig(model="test/model")
  811. )
  812. # Test a simple query
  813. messages = [Message(role="user", content="Query with no citations")]
  814. # Run the agent
  815. stream = agent.arun(messages=messages)
  816. output = await collect_stream_output(stream)
  817. # Verify no citation events were emitted
  818. citation_events = [line for line in output if 'event: citation' in line]
  819. assert len(citation_events) == 0, "No citation events should be emitted"
  820. # Parse the final answer event to check citations
  821. final_answer_events = [line for line in output if 'event: agent.final_answer' in line]
  822. assert len(final_answer_events) > 0, "Final answer event should be emitted"
  823. data_part = final_answer_events[0].split('data: ')[1] if 'data: ' in final_answer_events[0] else ""
  824. # Parse final answer data
  825. try:
  826. data = json.loads(data_part)
  827. assert 'citations' in data, "Final answer event should include citations field"
  828. assert len(data['citations']) == 0, "Citations list should be empty"
  829. except json.JSONDecodeError:
  830. assert False, "Final answer event data should be valid JSON"
  831. @pytest.mark.asyncio
  832. async def test_citation_sanitization():
  833. """Test that citation IDs are properly sanitized before processing."""
  834. # Since extract_citations uses a strict regex pattern [A-Za-z0-9]{7,8},
  835. # we should test with valid citation formats
  836. text = "Citation with surrounding text[abc1234]and [def5678]with no spaces."
  837. # Extract citations
  838. citations = extract_citations(text)
  839. # Check if citations are properly extracted
  840. assert "abc1234" in citations, "Citation abc1234 should be extracted"
  841. assert "def5678" in citations, "Citation def5678 should be extracted"
  842. # Test with spaces - these should NOT be extracted based on the implementation
  843. text_with_spaces = "Citation with [abc1234 ] and [ def5678] spaces."
  844. citations_with_spaces = extract_citations(text_with_spaces)
  845. # The current implementation doesn't extract citations with spaces inside the brackets
  846. assert len(citations_with_spaces) == 0 or "abc1234" not in citations_with_spaces, "Citations with spaces should not be extracted with current implementation"
  847. @pytest.mark.asyncio
  848. async def test_citation_tracking_state_persistence():
  849. """Test that the CitationTracker correctly maintains state across multiple calls."""
  850. tracker = CitationTracker()
  851. # Record some initial spans
  852. tracker.is_new_span("abc1234", (10, 18))
  853. tracker.is_new_span("def5678", (30, 38))
  854. # Check if spans are correctly stored
  855. all_spans = tracker.get_all_spans()
  856. assert "abc1234" in all_spans, "Citation abc1234 should be tracked"
  857. assert "def5678" in all_spans, "Citation def5678 should be tracked"
  858. assert all_spans["abc1234"] == [(10, 18)], "Span positions should match"
  859. # Add another span for an existing citation
  860. tracker.is_new_span("abc1234", (50, 58))
  861. # Check if the new span was added
  862. all_spans = tracker.get_all_spans()
  863. assert len(all_spans["abc1234"]) == 2, "Citation abc1234 should have 2 spans"
  864. assert (50, 58) in all_spans["abc1234"], "New span should be added"
  865. def test_citation_span_uniqueness():
  866. """Test that CitationTracker correctly identifies duplicate spans."""
  867. tracker = CitationTracker()
  868. # Record a span
  869. tracker.is_new_span("abc1234", (10, 18))
  870. # Check if the same span is recognized as not new
  871. assert not tracker.is_new_span("abc1234", (10, 18)), "Duplicate span should not be considered new"
  872. # Check if different span for same citation is recognized as new
  873. assert tracker.is_new_span("abc1234", (20, 28)), "Different span should be considered new"
  874. # Check if same span for different citation is recognized as new
  875. assert tracker.is_new_span("def5678", (10, 18)), "Same span for different citation should be considered new"
  876. def test_citation_with_punctuation():
  877. """Test extraction of citations with surrounding punctuation."""
  878. text = "Citations with punctuation: ([abc1234]), [def5678]!, and [ghi9012]."
  879. # Extract citations
  880. citations = extract_citations(text)
  881. # Check if all citations are extracted correctly
  882. assert "abc1234" in citations, "Citation abc1234 should be extracted"
  883. assert "def5678" in citations, "Citation def5678 should be extracted"
  884. assert "ghi9012" in citations, "Citation ghi9012 should be extracted"
  885. def test_citation_extraction_with_invalid_formats():
  886. """Test that invalid citation formats are not extracted."""
  887. text = "Invalid citation formats: [123], [abcdef], [abc123456789], and valid [abc1234]."
  888. # Extract citations
  889. citations = extract_citations(text)
  890. # Check that only valid citations are extracted
  891. assert len(citations) == 1, "Only one valid citation should be extracted"
  892. assert "abc1234" in citations, "Only valid citation abc1234 should be extracted"
  893. assert "123" not in citations, "Invalid citation [123] should not be extracted"
  894. assert "abcdef" not in citations, "Invalid citation [abcdef] should not be extracted"
  895. assert "abc123456789" not in citations, "Invalid citation [abc123456789] should not be extracted"