test_citations.py 41 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
370470570670770870971071171271371471571671771871972072172272372472572672772872973073173273373473573673773873974074174274374474574674774874975075175275375475575675775875976076176276376476576676776876977077177277377477577677777877978078178278378478578678778878979079179279379479579679779879980080180280380480580680780880981081181281381481581681781881982082182282382482582682782882983083183283383483583683783883984084184284384484584684784884985085185285385485585685785885986086186286386486586686786886987087187287387487587687787887988088188288388488588688788888989089189289389489589689789889990090190290390490590690790890991091191291391491591691791891992092192292392492592692792892993093193293393493593693793893994094194294394494594694794894995095195295395495595695795895996096196296396496596696796896997097197297397497597697797897998098198298398498598698798898999099199299399499599699799899910001001100210031004100510061007100810091010
  1. """
  2. Unit tests for citation handling in retrieval functionality.
  3. """
  4. import pytest
  5. import re
  6. from unittest.mock import AsyncMock, MagicMock, patch
  7. from typing import Dict, List, Any, Optional
  8. # Import citation utilities from core.utils
  9. from core.utils import (
  10. extract_citations,
  11. extract_citation_spans,
  12. find_new_citation_spans,
  13. CitationTracker as CoreCitationTracker
  14. )
  15. class CitationTracker:
  16. """Simple citation tracker for testing."""
  17. def __init__(self):
  18. # Track which citation spans we've processed
  19. # Format: {citation_id: {(start, end), (start, end), ...}}
  20. self.processed_spans = {}
  21. self.citation_spans = {}
  22. def is_new_span(self, citation_id, span):
  23. """Check if this span is new and mark it as processed if it is."""
  24. # Handle invalid inputs
  25. if citation_id is None or citation_id == "" or span is None:
  26. return False
  27. # Initialize set for this citation ID if needed
  28. if citation_id not in self.processed_spans:
  29. self.processed_spans[citation_id] = set()
  30. # Check if we've seen this span before for this citation
  31. if span in self.processed_spans[citation_id]:
  32. return False
  33. # This is a new span, track it
  34. self.processed_spans[citation_id].add(span)
  35. # Also track by citation ID for easy lookup
  36. if citation_id not in self.citation_spans:
  37. self.citation_spans[citation_id] = []
  38. self.citation_spans[citation_id].append(span)
  39. return True
  40. def get_all_citation_spans(self):
  41. """Get all citation spans processed so far."""
  42. return {
  43. citation_id: spans
  44. for citation_id, spans in self.citation_spans.items()
  45. }
  46. class MockCitation:
  47. """Mock Citation class for testing."""
  48. def __init__(self, citation_id, chunk_id=None, document_id=None, text=None, metadata=None):
  49. self.citation_id = citation_id
  50. self.chunk_id = chunk_id or f"chunk-{citation_id}"
  51. self.document_id = document_id or f"doc-{citation_id}"
  52. self.text = text or f"Citation text for {citation_id}"
  53. self.metadata = metadata or {"source": f"source-{citation_id}"}
  54. self.spans = []
  55. @pytest.fixture
  56. def mock_providers():
  57. """Return a mocked providers object for testing."""
  58. class MockProviders:
  59. def __init__(self):
  60. # Mock the database
  61. self.database = AsyncMock()
  62. self.database.citations_handler = AsyncMock()
  63. self.database.citations_handler.get_citation = AsyncMock(
  64. side_effect=lambda citation_id: MockCitation(citation_id)
  65. )
  66. # Mock LLM
  67. self.llm = AsyncMock()
  68. self.llm.aget_completion = AsyncMock(
  69. return_value={"choices": [{"message": {"content": "Response with [abc1234] citation"}}]}
  70. )
  71. self.llm.aget_completion_stream = AsyncMock(
  72. return_value=iter([
  73. {"choices": [{"delta": {"content": "Response "}}]},
  74. {"choices": [{"delta": {"content": "with "}}]},
  75. {"choices": [{"delta": {"content": "[abc1234] "}}]},
  76. {"choices": [{"delta": {"content": "citation"}}]}
  77. ])
  78. )
  79. return MockProviders()
  80. @pytest.fixture
  81. def sample_chunk_results():
  82. """Return sample chunk results with citation metadata."""
  83. return [
  84. {
  85. "chunk_id": f"chunk-{i}",
  86. "document_id": f"doc-{i//2}",
  87. "text": f"This is chunk {i} with information about the topic.",
  88. "metadata": {
  89. "source": f"source-{i}",
  90. "citation_id": f"cite{i}"
  91. },
  92. "score": 0.95 - (i * 0.05),
  93. }
  94. for i in range(5)
  95. ]
  96. class TestCitationExtraction:
  97. """Tests for citation extraction functionality."""
  98. def test_extract_citations_basic(self):
  99. """Test basic citation extraction from text with standard format."""
  100. # Test function to extract citations
  101. def extract_citations(text):
  102. citation_pattern = r'\[([\w\d]+)\]'
  103. citations = re.findall(citation_pattern, text)
  104. return citations
  105. # Test cases
  106. test_cases = [
  107. (
  108. "Aristotle discussed virtue ethics in his Nicomachean Ethics [abc1234].",
  109. ["abc1234"]
  110. ),
  111. (
  112. "According to Plato [xyz5678] and Aristotle [abc1234], philosophy is important.",
  113. ["xyz5678", "abc1234"]
  114. ),
  115. (
  116. "This text has no citations.",
  117. []
  118. ),
  119. (
  120. "Multiple citations in a row [abc1234][def5678][ghi9012] should all be found.",
  121. ["abc1234", "def5678", "ghi9012"]
  122. )
  123. ]
  124. # Run tests
  125. for text, expected_citations in test_cases:
  126. extracted = extract_citations(text)
  127. assert extracted == expected_citations
  128. def test_extract_citations_with_spans(self):
  129. """Test citation extraction with text spans."""
  130. # Test function to extract citations with spans
  131. def extract_citations_with_spans(text):
  132. citation_pattern = r'\[([\w\d]+)\]'
  133. citations_with_spans = []
  134. for match in re.finditer(citation_pattern, text):
  135. citation_id = match.group(1)
  136. start = match.start()
  137. end = match.end()
  138. # Get the context (text before and after the citation)
  139. context_start = max(0, start - 50)
  140. context_end = min(len(text), end + 50)
  141. context = text[context_start:context_end]
  142. citations_with_spans.append({
  143. "citation_id": citation_id,
  144. "start": start,
  145. "end": end,
  146. "context": context
  147. })
  148. return citations_with_spans
  149. # Test text
  150. text = (
  151. "Aristotle discussed virtue ethics in his Nicomachean Ethics [abc1234]. "
  152. "According to Plato [xyz5678], the ideal state is described in The Republic. "
  153. "Socrates' method of questioning is demonstrated in many dialogues [ghi9012]."
  154. )
  155. # Extract citations with spans
  156. extracted = extract_citations_with_spans(text)
  157. # Verify the correct number of citations was extracted
  158. assert len(extracted) == 3
  159. # Verify citation IDs are correct
  160. assert extracted[0]["citation_id"] == "abc1234"
  161. assert extracted[1]["citation_id"] == "xyz5678"
  162. assert extracted[2]["citation_id"] == "ghi9012"
  163. # Verify spans and context
  164. for citation in extracted:
  165. assert citation["start"] < citation["end"]
  166. assert text[citation["start"]:citation["end"]] == f"[{citation['citation_id']}]"
  167. assert citation["citation_id"] in citation["context"]
  168. def test_citation_extraction_edge_cases(self):
  169. """Test citation extraction with edge cases and malformed citations."""
  170. # Test function to extract citations that exactly matches the implementation in core.utils
  171. def extract_citations(text):
  172. # Handle None or empty input
  173. if text is None or text == "":
  174. return []
  175. # Match the core implementation pattern: 7-8 alphanumeric chars
  176. citation_pattern = re.compile(r"\[([A-Za-z0-9]{7,8})\]")
  177. sids = []
  178. for match in citation_pattern.finditer(text):
  179. sid = match.group(1)
  180. sids.append(sid)
  181. return sids
  182. # Edge case tests
  183. test_cases = [
  184. (
  185. "Incomplete citation [abc1234", # Missing closing bracket
  186. [] # This would not match with the regular pattern
  187. ),
  188. (
  189. "Empty citation []", # Empty citation
  190. [] # This would match but capture an empty string
  191. ),
  192. (
  193. "Citation with special chars [abc-1234]", # Contains hyphen
  194. [] # Should not capture because hyphen is not allowed in the pattern
  195. ),
  196. (
  197. "Citation at the end of sentence[abcd1234].", # No space before citation
  198. ["abcd1234"] # Should still capture
  199. ),
  200. (
  201. "Valid citation [abc1234]", # Valid citation
  202. ["abc1234"] # Should capture
  203. ),
  204. (
  205. "Text with [short] but no valid citation format.", # 'short' is only 5 chars, too short
  206. [] # Should not extract non-citation brackets with wrong length
  207. ),
  208. (
  209. "Text with [abc123] (too short) and [abcdefghi] (too long).",
  210. [] # Should not extract brackets with wrong length
  211. ),
  212. (
  213. "Text with [abc-1234] has the right length but contains special characters.",
  214. [] # Should not extract brackets with special characters
  215. ),
  216. ]
  217. # Run tests
  218. for text, expected_citations in test_cases:
  219. extracted = extract_citations(text)
  220. assert extracted == expected_citations
  221. def test_citation_sanitization(self):
  222. """Test sanitization of citation IDs."""
  223. # Function to sanitize citation IDs
  224. def sanitize_citation_id(citation_id):
  225. # Remove any non-alphanumeric characters
  226. return re.sub(r'[^a-zA-Z0-9]', '', citation_id)
  227. # Test cases
  228. test_cases = [
  229. ("abc1234", "abc1234"), # Already clean
  230. ("abc-1234", "abc1234"), # Contains hyphen
  231. ("abc.1234", "abc1234"), # Contains period
  232. ("abc_1234", "abc1234"), # Contains underscore
  233. ("abc 1234", "abc1234"), # Contains space
  234. ]
  235. # Run tests
  236. for input_id, expected_id in test_cases:
  237. sanitized = sanitize_citation_id(input_id)
  238. assert sanitized == expected_id
  239. class TestCitationTracker:
  240. """Tests for citation tracking functionality."""
  241. def test_citation_tracker_init(self):
  242. """Test initialization of citation tracker."""
  243. tracker = CitationTracker()
  244. assert hasattr(tracker, 'processed_spans')
  245. assert hasattr(tracker, 'citation_spans')
  246. assert isinstance(tracker.processed_spans, dict)
  247. assert isinstance(tracker.citation_spans, dict)
  248. assert len(tracker.processed_spans) == 0
  249. assert len(tracker.citation_spans) == 0
  250. def test_is_new_span(self):
  251. """Test is_new_span method."""
  252. tracker = CitationTracker()
  253. # First occurrence should be new
  254. assert tracker.is_new_span("abc1234", (10, 18)) is True
  255. # Same span should not be new anymore
  256. assert tracker.is_new_span("abc1234", (10, 18)) is False
  257. # Different span for same citation should be new
  258. assert tracker.is_new_span("abc1234", (30, 38)) is True
  259. # Different citation ID should be new
  260. assert tracker.is_new_span("def5678", (10, 18)) is True
  261. def test_get_all_citation_spans(self):
  262. """Test get_all_citation_spans method."""
  263. tracker = CitationTracker()
  264. # Add some spans
  265. tracker.is_new_span("abc1234", (10, 18))
  266. tracker.is_new_span("abc1234", (30, 38))
  267. tracker.is_new_span("def5678", (50, 58))
  268. # Get all spans
  269. all_spans = tracker.get_all_citation_spans()
  270. # Verify results
  271. assert "abc1234" in all_spans
  272. assert "def5678" in all_spans
  273. assert len(all_spans["abc1234"]) == 2
  274. assert len(all_spans["def5678"]) == 1
  275. assert (10, 18) in all_spans["abc1234"]
  276. assert (30, 38) in all_spans["abc1234"]
  277. assert (50, 58) in all_spans["def5678"]
  278. def test_citation_tracker_multiple_spans(self):
  279. """Test tracking multiple citation spans."""
  280. tracker = CitationTracker()
  281. # Sample text with multiple citations
  282. text = (
  283. "Aristotle discussed virtue ethics in his Nicomachean Ethics [abc1234]. "
  284. "Later in the same work [abc1234], he expanded on this concept. "
  285. "According to Plato [def5678], the ideal state is described in The Republic."
  286. )
  287. # Extract and track citations
  288. citation_pattern = r'\[([\w\d]+)\]'
  289. for match in re.finditer(citation_pattern, text):
  290. citation_id = match.group(1)
  291. start = match.start()
  292. end = match.end()
  293. tracker.is_new_span(citation_id, (start, end))
  294. # Verify tracking
  295. all_spans = tracker.get_all_citation_spans()
  296. assert len(all_spans["abc1234"]) == 2
  297. assert len(all_spans["def5678"]) == 1
  298. class TestCitationStreamingEvents:
  299. """Tests for citation events during streaming."""
  300. def test_emit_citation_event(self):
  301. """Test emitting a citation event during streaming."""
  302. # Create a mock agent
  303. class MockAgent:
  304. def __init__(self):
  305. self.emitted_events = []
  306. def emit_event(self, event):
  307. self.emitted_events.append(event)
  308. agent = MockAgent()
  309. # Function to emit a citation event
  310. def emit_citation_event(agent, citation_id, start, end, text_context):
  311. event = {
  312. "type": "citation",
  313. "data": {
  314. "citation_id": citation_id,
  315. "start": start,
  316. "end": end,
  317. "text_context": text_context
  318. }
  319. }
  320. agent.emit_event(event)
  321. # Emit an event
  322. emit_citation_event(agent, "abc1234", 10, 18, "text with [abc1234] citation")
  323. # Verify event
  324. assert len(agent.emitted_events) == 1
  325. event = agent.emitted_events[0]
  326. assert event["type"] == "citation"
  327. assert event["data"]["citation_id"] == "abc1234"
  328. assert event["data"]["start"] == 10
  329. assert event["data"]["end"] == 18
  330. def test_citation_tracking_during_streaming(self):
  331. """Test tracking citations during streaming."""
  332. # Create a mock agent with citation tracker
  333. class MockAgent:
  334. def __init__(self):
  335. self.emitted_events = []
  336. self.citation_tracker = CitationTracker()
  337. def emit_event(self, event):
  338. self.emitted_events.append(event)
  339. agent = MockAgent()
  340. # Function to process streaming text and emit citation events
  341. def process_streaming_text(agent, text, start_offset=0):
  342. # Extract citations
  343. citation_pattern = r'\[([\w\d]+)\]'
  344. for match in re.finditer(citation_pattern, text):
  345. citation_id = match.group(1)
  346. start = match.start() + start_offset
  347. end = match.end() + start_offset
  348. # Check if this is a new span
  349. if agent.citation_tracker.is_new_span(citation_id, (start, end)):
  350. # Get context
  351. context_start = max(0, match.start() - 10)
  352. context_end = min(len(text), match.end() + 10)
  353. context = text[context_start:context_end]
  354. # Emit event
  355. event = {
  356. "type": "citation",
  357. "data": {
  358. "citation_id": citation_id,
  359. "start": start,
  360. "end": end,
  361. "text_context": context
  362. }
  363. }
  364. agent.emit_event(event)
  365. # Process streaming text in chunks
  366. chunks = [
  367. "Aristotle discussed virtue ethics ",
  368. "in his Nicomachean Ethics [abc1234]. ",
  369. "According to Plato [def5678], ",
  370. "the ideal state is described in The Republic. ",
  371. "Later, Aristotle also mentioned [abc1234] this concept."
  372. ]
  373. offset = 0
  374. for chunk in chunks:
  375. process_streaming_text(agent, chunk, offset)
  376. offset += len(chunk)
  377. # Verify events and tracking
  378. assert len(agent.emitted_events) == 3 # 3 citations total (2 abc1234, 1 def5678)
  379. # Verify citation IDs in events
  380. citation_ids = [event["data"]["citation_id"] for event in agent.emitted_events]
  381. assert citation_ids.count("abc1234") == 2
  382. assert citation_ids.count("def5678") == 1
  383. # Verify tracker state
  384. all_spans = agent.citation_tracker.get_all_citation_spans()
  385. assert len(all_spans["abc1234"]) == 2
  386. assert len(all_spans["def5678"]) == 1
  387. class TestRAGWithCitations:
  388. """Tests for RAG functionality with citations."""
  389. @pytest.mark.asyncio
  390. async def test_rag_with_citation_metadata(self, mock_providers, sample_chunk_results):
  391. """Test RAG with citation metadata in search results."""
  392. # Function to build a RAG prompt with citations
  393. def build_rag_prompt_with_citations(query, search_results):
  394. context = ""
  395. citation_metadata = {}
  396. for i, result in enumerate(search_results):
  397. # Extract citation information
  398. citation_id = result.get("metadata", {}).get("citation_id")
  399. if citation_id:
  400. # Add to context with citation marker
  401. context += f"\n[{i+1}] {result['text']} [{citation_id}]"
  402. # Store metadata
  403. citation_metadata[citation_id] = {
  404. "document_id": result["document_id"],
  405. "chunk_id": result["chunk_id"],
  406. "metadata": result.get("metadata", {})
  407. }
  408. else:
  409. context += f"\n[{i+1}] {result['text']}"
  410. prompt = f"Question: {query}\n\nContext:{context}\n\nPlease answer the question based on the provided context."
  411. return prompt, citation_metadata
  412. # Build prompt
  413. query = "What is the main concept?"
  414. prompt, citation_metadata = build_rag_prompt_with_citations(query, sample_chunk_results)
  415. # Verify prompt contains citations
  416. for i in range(5):
  417. assert f"[cite{i}]" in prompt
  418. # Verify metadata is stored
  419. assert len(citation_metadata) == 5
  420. for i in range(5):
  421. assert f"cite{i}" in citation_metadata
  422. assert citation_metadata[f"cite{i}"]["document_id"] == f"doc-{i//2}"
  423. assert citation_metadata[f"cite{i}"]["chunk_id"] == f"chunk-{i}"
  424. @pytest.mark.asyncio
  425. async def test_rag_response_with_citations(self, mock_providers, sample_chunk_results):
  426. """Test generating a RAG response with citations."""
  427. # Function to generate RAG response with citations
  428. async def generate_rag_response_with_citations(query, search_results):
  429. # Build prompt with citations
  430. context = ""
  431. citation_metadata = {}
  432. for i, result in enumerate(search_results):
  433. citation_id = result.get("metadata", {}).get("citation_id")
  434. if citation_id:
  435. context += f"\n[{i+1}] {result['text']} [{citation_id}]"
  436. citation_metadata[citation_id] = {
  437. "document_id": result["document_id"],
  438. "chunk_id": result["chunk_id"],
  439. "metadata": result.get("metadata", {})
  440. }
  441. else:
  442. context += f"\n[{i+1}] {result['text']}"
  443. prompt = f"Question: {query}\n\nContext:{context}\n\nPlease answer the question based on the provided context."
  444. # Generate response (mocked)
  445. # In real implementation, this would call the LLM
  446. mock_providers.llm.aget_completion.return_value = {
  447. "choices": [{
  448. "message": {
  449. "content": "The main concept is explained in [cite0] and further elaborated in [cite2]."
  450. }
  451. }]
  452. }
  453. response = await mock_providers.llm.aget_completion(prompt=prompt)
  454. content = response["choices"][0]["message"]["content"]
  455. return content, citation_metadata
  456. # Generate response
  457. query = "What is the main concept?"
  458. response, citation_metadata = await generate_rag_response_with_citations(query, sample_chunk_results)
  459. # Verify response contains citations
  460. assert "[cite0]" in response
  461. assert "[cite2]" in response
  462. # Extract citations from response
  463. def extract_citations_from_response(text):
  464. citation_pattern = r'\[([\w\d]+)\]'
  465. citations = re.findall(citation_pattern, text)
  466. return citations
  467. citations = extract_citations_from_response(response)
  468. assert "cite0" in citations
  469. assert "cite2" in citations
  470. @pytest.mark.asyncio
  471. async def test_consolidate_citations_in_final_answer(self, mock_providers):
  472. """Test consolidating citations in the final answer."""
  473. # Create a citation tracker with some spans
  474. tracker = CitationTracker()
  475. tracker.is_new_span("cite0", (10, 18))
  476. tracker.is_new_span("cite0", (30, 38))
  477. tracker.is_new_span("cite2", (50, 58))
  478. # Create citation metadata
  479. citation_metadata = {
  480. "cite0": {
  481. "document_id": "doc-0",
  482. "chunk_id": "chunk-0",
  483. "metadata": {"source": "source-0", "title": "Document 0"}
  484. },
  485. "cite2": {
  486. "document_id": "doc-1",
  487. "chunk_id": "chunk-2",
  488. "metadata": {"source": "source-2", "title": "Document 1"}
  489. }
  490. }
  491. # Function to consolidate citations
  492. def consolidate_citations(response_text, citation_tracker, citation_metadata):
  493. # Get all citations from the tracker
  494. all_citation_spans = citation_tracker.get_all_citation_spans()
  495. # Build consolidated citations
  496. consolidated_citations = {}
  497. for citation_id, spans in all_citation_spans.items():
  498. if citation_id in citation_metadata:
  499. metadata = citation_metadata[citation_id]
  500. consolidated_citations[citation_id] = {
  501. "spans": spans,
  502. "document_id": metadata["document_id"],
  503. "chunk_id": metadata["chunk_id"],
  504. "metadata": metadata["metadata"]
  505. }
  506. # Return the response with consolidated citations
  507. return {
  508. "response": response_text,
  509. "citations": consolidated_citations
  510. }
  511. # Test response
  512. response_text = "The main concept is explained in [cite0] and further elaborated in [cite2]."
  513. # Consolidate citations
  514. result = consolidate_citations(response_text, tracker, citation_metadata)
  515. # Verify result
  516. assert "response" in result
  517. assert "citations" in result
  518. assert result["response"] == response_text
  519. # Verify consolidated citations
  520. assert "cite0" in result["citations"]
  521. assert "cite2" in result["citations"]
  522. assert len(result["citations"]["cite0"]["spans"]) == 2
  523. assert len(result["citations"]["cite2"]["spans"]) == 1
  524. assert result["citations"]["cite0"]["document_id"] == "doc-0"
  525. assert result["citations"]["cite2"]["document_id"] == "doc-1"
  526. class TestCitationUtils:
  527. """Tests for citation utility functions."""
  528. def test_extract_citations(self):
  529. """Test that citations are correctly extracted from text."""
  530. # Simple case with one citation
  531. text = "This is a test with a citation [abc1234]."
  532. citations = extract_citations(text)
  533. assert citations == ["abc1234"], "Should extract a single citation ID"
  534. # Multiple citations
  535. text = "First citation [abc1234] and second citation [def5678]."
  536. citations = extract_citations(text)
  537. assert citations == ["abc1234", "def5678"], "Should extract multiple citation IDs"
  538. # Repeated citations
  539. text = "Same citation twice [abc1234] and again [abc1234]."
  540. citations = extract_citations(text)
  541. assert len(citations) == 2, "Should extract duplicate citation IDs"
  542. assert citations == ["abc1234", "abc1234"], "Should preserve order of citations"
  543. def test_extract_citations_edge_cases(self):
  544. """Test edge cases for citation extraction."""
  545. # Define local extract_citations for testing that follows the core implementation
  546. def local_extract_citations(text):
  547. # Handle None or empty input
  548. if text is None or text == "":
  549. return []
  550. # Match the core implementation pattern: 7-8 alphanumeric chars
  551. citation_pattern = re.compile(r"\[([A-Za-z0-9]{7,8})\]")
  552. sids = []
  553. for match in citation_pattern.finditer(text):
  554. sid = match.group(1)
  555. sids.append(sid)
  556. return sids
  557. # Citations at beginning or end of text
  558. text = "[abc1234] at the beginning and at the end [def5678]"
  559. citations = local_extract_citations(text)
  560. assert citations == ["abc1234", "def5678"], "Should extract citations at beginning and end"
  561. # Empty text
  562. text = ""
  563. citations = local_extract_citations(text)
  564. assert citations == [], "Should handle empty text gracefully"
  565. # None input
  566. citations = local_extract_citations(None)
  567. assert citations == [], "Should handle None input gracefully"
  568. # Text with brackets but no valid citation format
  569. text = "Text with [short] but no valid citation format."
  570. citations = local_extract_citations(text)
  571. assert citations == [], "Should not extract non-citation brackets (too short)"
  572. # Text with brackets but wrong length
  573. text = "Text with [abc123] (too short) and [abcdefghi] (too long)."
  574. citations = local_extract_citations(text)
  575. assert citations == [], "Should not extract brackets with wrong length"
  576. # Text with brackets that have correct length but non-alphanumeric chars
  577. text = "Text with [abc-1234] has the right length but contains special characters."
  578. citations = local_extract_citations(text)
  579. assert citations == [], "Should not extract brackets with special characters"
  580. # Text with close brackets only
  581. text = "Text with close brackets only]."
  582. citations = local_extract_citations(text)
  583. assert citations == [], "Should not extract when only close brackets present"
  584. def test_extract_citation_spans(self):
  585. """Test that citation spans are correctly extracted with positions."""
  586. # Simple case with one citation
  587. text = "This is a test with a citation [abc1234]."
  588. spans = extract_citation_spans(text)
  589. assert len(spans) == 1, "Should extract one citation ID"
  590. assert "abc1234" in spans, "Citation ID should be a key in the dictionary"
  591. assert len(spans["abc1234"]) == 1, "Should have one span for this citation"
  592. start, end = spans["abc1234"][0]
  593. assert text[start:end] == "[abc1234]", "Span positions should be correct"
  594. # Multiple citations
  595. text = "First citation [abc1234] and second citation [def5678]."
  596. spans = extract_citation_spans(text)
  597. assert len(spans) == 2, "Should extract two citation IDs"
  598. assert "abc1234" in spans, "First citation ID should be present"
  599. assert "def5678" in spans, "Second citation ID should be present"
  600. assert len(spans["abc1234"]) == 1, "Should have one span for first citation"
  601. assert len(spans["def5678"]) == 1, "Should have one span for second citation"
  602. start1, end1 = spans["abc1234"][0]
  603. start2, end2 = spans["def5678"][0]
  604. assert text[start1:end1] == "[abc1234]", "First span positions should be correct"
  605. assert text[start2:end2] == "[def5678]", "Second span positions should be correct"
  606. def test_extract_citation_spans_edge_cases(self):
  607. """Test edge cases for citation span extraction."""
  608. # Citations at beginning or end of text
  609. text = "[abc1234] at the beginning and at the end [def5678]"
  610. spans = extract_citation_spans(text)
  611. assert len(spans) == 2, "Should extract two spans"
  612. assert "abc1234" in spans, "First citation ID should be present"
  613. assert "def5678" in spans, "Second citation ID should be present"
  614. assert len(spans["abc1234"]) == 1, "Should have one span for first citation"
  615. assert len(spans["def5678"]) == 1, "Should have one span for second citation"
  616. start1, end1 = spans["abc1234"][0]
  617. start2, end2 = spans["def5678"][0]
  618. assert text[start1:end1] == "[abc1234]", "First span should start at beginning"
  619. assert text[start2:end2] == "[def5678]", "Second span should end at end"
  620. # Empty text
  621. text = ""
  622. spans = extract_citation_spans(text)
  623. assert spans == {}, "Should return empty dictionary for empty text"
  624. # None input
  625. spans = extract_citation_spans(None)
  626. assert spans == {}, "Should return empty dictionary for None input"
  627. # Overlapping brackets
  628. text = "Text with overlapping [abc1234] brackets [def5678]."
  629. spans = extract_citation_spans(text)
  630. assert len(spans) == 2, "Should extract two spans correctly even with proximity"
  631. assert "abc1234" in spans, "First citation ID should be present"
  632. assert "def5678" in spans, "Second citation ID should be present"
  633. assert len(spans["abc1234"]) == 1, "Should have one span for first citation"
  634. assert len(spans["def5678"]) == 1, "Should have one span for second citation"
  def test_core_citation_tracker(self):
      """Exercise the basic span bookkeeping of CitationTracker.

      Covers an empty initial state, recording a first span, rejecting an
      exact duplicate, accepting a second distinct span for the same ID, and
      reading everything back through ``get_all_citation_spans()``.
      """
      tracker = CitationTracker()
      cid = "abc1234"
      first_span, second_span = (10, 20), (30, 40)

      # A fresh tracker has recorded nothing yet.
      assert len(tracker.processed_spans) == 0, "Should start with empty citation spans"

      # The first sighting of a span is new and is stored under its ID.
      assert tracker.is_new_span(cid, first_span), "First span should be considered new"
      assert cid in tracker.processed_spans, "Citation ID should be in processed_spans"
      assert first_span in tracker.processed_spans[cid], "Span should be recorded"

      # Re-submitting the identical span is rejected and not stored twice.
      assert not tracker.is_new_span(cid, first_span), "Duplicate span should not be considered new"
      assert len(tracker.processed_spans[cid]) == 1, "Duplicate span should not be added again"

      # A different position for the same ID counts as new.
      assert tracker.is_new_span(cid, second_span), "Different span for same citation should be new"
      assert len(tracker.processed_spans[cid]) == 2, "New span should be added"
      assert second_span in tracker.processed_spans[cid], "New span should be recorded"

      # The aggregate view from get_all_citation_spans() reports both spans.
      all_spans = tracker.get_all_citation_spans()
      assert cid in all_spans, "Citation ID should be in all spans"
      assert len(all_spans[cid]) == 2, "Should have 2 spans for the citation"
  def test_core_citation_tracker_edge_cases(self):
      """Test edge cases for the core CitationTracker class.

      Empty and ``None`` citation IDs must be rejected by ``is_new_span`` and
      must not appear in ``get_all_citation_spans()``. Degenerate spans
      (negative start, end before start), overlapping spans and very large
      spans are accepted as ordinary spans.
      """
      tracker = CitationTracker()
      # Invalid IDs are rejected outright.
      assert not tracker.is_new_span("", (10, 20)), "Empty citation ID should not be tracked"
      assert not tracker.is_new_span(None, (10, 20)), "None citation ID should not be tracked"
      # Span geometry is not validated by the tracker (implementation dependent).
      assert tracker.is_new_span("abc1234", (-5, 20)), "Negative start position should be accepted"
      assert tracker.is_new_span("abc1234", (30, 20)), "End before start should be accepted (implementation dependent)"
      # Overlapping spans are still distinct spans, so both count as new.
      assert tracker.is_new_span("def5678", (10, 30)), "First overlapping span should be new"
      assert tracker.is_new_span("def5678", (20, 40)), "Second overlapping span should be new"
      assert len(tracker.processed_spans["def5678"]) == 2, "Both overlapping spans should be recorded"
      # Very large spans are stored unchanged.
      assert tracker.is_new_span("large", (0, 10000)), "Very large span should be tracked"
      assert (0, 10000) in tracker.processed_spans["large"], "Large span should be recorded correctly"
      # Aggregate view: every accepted ID is present...
      all_spans = tracker.get_all_citation_spans()
      assert len(all_spans) >= 3, "Should have at least 3 different citation IDs"
      for citation_id in ("abc1234", "def5678", "large"):
          assert citation_id in all_spans, f"Accepted citation ID {citation_id!r} should be in all spans"
      # ...and the IDs rejected by is_new_span above are absent.
      assert "" not in all_spans, "Empty citation ID should not be in all spans"
      assert None not in all_spans, "None citation ID should not be in all spans"
  def test_find_new_citation_spans(self):
      """Check that find_new_citation_spans only reports unseen citations.

      Four texts go through one tracker:
      - a first citation;
      - the same text again;
      - a second citation;
      - a mix of an already-seen and an unseen citation.
      """
      tracker = CitationTracker()

      # A single citation in fresh text is reported and recorded.
      first_text = "This is a text with citation [abc1234]."
      found = find_new_citation_spans(first_text, tracker)
      assert len(found) == 1, "Should find one new span"
      assert found[0][0] == "abc1234", "Citation ID should match"
      cid, span_start, span_end = found[0]
      assert cid in tracker.processed_spans, "Citation ID should be tracked"
      assert (span_start, span_end) in tracker.processed_spans[cid], "Span should be tracked"

      # Replaying the identical text yields nothing new.
      replay = find_new_citation_spans(first_text, tracker)
      assert replay == [], "Should not find duplicate spans"

      # A previously unseen citation is reported.
      found = find_new_citation_spans("This is another text with a new citation [def5678].", tracker)
      assert len(found) == 1, "Should find one new span"
      assert found[0][0] == "def5678", "New citation ID should match"

      # With an old and a new citation together, only the new one comes back.
      found = find_new_citation_spans("Text with both [abc1234] and [ghi9012].", tracker)
      assert len(found) == 1, "Should only find the new span"
      assert found[0][0] == "ghi9012", "Only new citation ID should be found"
  def test_find_new_citation_spans_edge_cases(self):
      """Check degenerate inputs and multi-citation text for find_new_citation_spans."""
      tracker = CitationTracker()

      # Empty text, citation-free text and None all produce an empty list.
      degenerate_cases = (
          ("", "Should return empty list for empty text"),
          ("This text has no citations or brackets.", "Should return empty list for text without citations"),
          (None, "Should handle None input gracefully and return empty list"),
      )
      for sample, message in degenerate_cases:
          assert find_new_citation_spans(sample, tracker) == [], message

      # Several distinct citations in one text are all reported.
      text = "Text with multiple citations [abc1234] and [def5678] and [ghi9012]."
      new_spans = find_new_citation_spans(text, tracker)
      assert len(new_spans) == 3, "Should find three new spans"
      found_ids = {span[0] for span in new_spans}
      expectations = (
          ("abc1234", "First citation should be found"),
          ("def5678", "Second citation should be found"),
          ("ghi9012", "Third citation should be found"),
      )
      for expected_id, message in expectations:
          assert expected_id in found_ids, message
  def test_performance_with_many_citations(self):
      """Check extraction and span tracking on a text with 100 distinct citations.

      Also simulates streaming by feeding the text to a fresh tracker in
      contiguous fixed-size chunks. Only citations cut by a chunk boundary
      may be lost.
      """
      num_citations = 100
      num_chunks = 10
      citations = [f"cit{i:04d}" for i in range(num_citations)]
      # Build the text with a single join instead of repeated += in a loop.
      body = "".join(
          f"Citation {i + 1}: [{citation}]. " for i, citation in enumerate(citations)
      )
      text = f"Beginning of text. {body}End of text."

      # Whole-text extraction sees every citation.
      extracted = extract_citations(text)
      assert len(extracted) == num_citations, "Should extract all 100 citations"
      spans = extract_citation_spans(text)
      assert len(spans) == num_citations, "Should extract all 100 spans"

      # A fresh tracker treats every citation in the full text as new.
      tracker = CitationTracker()
      new_spans = find_new_citation_spans(text, tracker)
      assert len(new_spans) == num_citations, "Should find all 100 spans as new"

      # Streaming simulation: equal chunks, the last one absorbs the remainder.
      chunk_size = len(text) // num_chunks
      boundaries = [i * chunk_size for i in range(num_chunks)] + [len(text)]
      tracker2 = CitationTracker()
      total_new_spans = 0
      for start, end in zip(boundaries, boundaries[1:]):
          chunk = text[start:end]
          # start_offset keeps reported spans absolute within the full text.
          new_spans_in_chunk = find_new_citation_spans(chunk, tracker2, start_offset=start)
          total_new_spans += len(new_spans_in_chunk)
      # A citation is missed only when a chunk boundary falls strictly inside
      # it. Citations do not overlap, so each of the (num_chunks - 1) interior
      # boundaries can split at most one citation.
      max_split = num_chunks - 1
      assert num_citations - max_split <= total_new_spans <= num_citations, (
          "Should find every citation except those split by a chunk boundary"
      )
  def test_streaming_citation_handling(self):
      """Feed text to a CitationTracker piece by piece, as a stream would.

      Each chunk is scanned on its own. Match positions are shifted by the
      number of characters received before that chunk, so recorded spans are
      absolute offsets into the concatenated stream.
      """
      tracker = CitationTracker()
      chunks = [
          "This is the first chunk ",
          "with no citations. This is the second chunk with a ",
          "citation [abc1234] and some more text. ",
          "This is the third chunk with another citation [def5678] ",
          "and the first citation again [abc1234] in a new position."
      ]
      citation_re = re.compile(r'\[([\w]{7,8})\]')
      consumed = 0  # characters received before the current chunk
      total_spans_found = 0
      for chunk in chunks:
          for match in citation_re.finditer(chunk):
              absolute_span = (consumed + match.start(), consumed + match.end())
              if tracker.is_new_span(match.group(1), absolute_span):
                  total_spans_found += 1
          consumed += len(chunk)

      # abc1234 appears twice at different offsets; def5678 appears once.
      assert "abc1234" in tracker.processed_spans, "First citation should be tracked"
      assert "def5678" in tracker.processed_spans, "Second citation should be tracked"
      assert len(tracker.processed_spans["abc1234"]) == 2, "First citation should have 2 spans"
      assert len(tracker.processed_spans["def5678"]) == 1, "Second citation should have 1 span"
      assert total_spans_found == 3, "Should have found 3 spans in total"
  def test_malformed_citations(self):
      """Only well-formed citations are picked up; malformed ones are ignored.

      The sample mixes one valid citation with several malformed ones:
      - missing brackets;
      - IDs that are too short or too long;
      - empty brackets;
      - an ID containing a non-word character.
      """
      text = "\n".join([
          "",
          "This text has citations with issues:",
          "- Missing end bracket [abc1234",
          "- Missing start bracket def5678]",
          "- Wrong format [abc123] (too short)",
          "- Wrong format [abcdefghi] (too long)",
          "- Valid citation [abc1234]",
          "- Empty brackets []",
          "- Non-alphanumeric [abc@123]",
          "",
      ])
      valid_id = "abc1234"

      # Plain ID extraction.
      citations = extract_citations(text)
      assert len(citations) == 1, "Should only extract the one valid citation"
      assert citations[0] == valid_id, "Valid citation should be extracted"

      # Span extraction.
      spans = extract_citation_spans(text)
      assert len(spans) == 1, "Should only extract span for the valid citation"
      assert valid_id in spans, "Valid citation span should be extracted"

      # Tracker-based discovery.
      tracker = CitationTracker()
      new_spans = find_new_citation_spans(text, tracker)
      assert len(new_spans) == 1, "Should only find one new valid citation span"
      assert new_spans[0][0] == valid_id, "Valid citation should be found"
      assert len(tracker.processed_spans) == 1, "Should only track the valid citation"
  def find_new_citation_spans(text, tracker, start_offset=0):
      """Find new citation spans in text that haven't been processed yet.

      Only citation IDs absent from ``tracker.processed_spans`` at the start
      of the call are considered. Occurrences of already-known IDs are skipped
      without being recorded, even at new positions.

      For the remaining IDs, each match is passed to ``tracker.is_new_span``,
      which also records it, and is kept if accepted. The set of known IDs is
      taken once up front, so an ID first seen in this call can be reported
      more than once if it occurs at several positions in ``text``.

      NOTE(review): this appears to be a module-level helper living in the
      test module. If the module also imports a same-named production
      function, this definition shadows it for the tests. Confirm which
      implementation the tests are meant to exercise.

      Args:
          text: Text to scan; ``None`` or ``""`` yields an empty list.
          tracker: Object exposing ``processed_spans`` (a mapping keyed by
              citation ID) and ``is_new_span(citation_id, (start, end))``.
          start_offset: Added to match positions so spans are absolute when
              ``text`` is a chunk of a larger document.

      Returns:
          List of ``(citation_id, start, end)`` tuples in order of appearance.
      """
      if text is None or text == "":
          return []
      # Snapshot of IDs the tracker knew about before this call.
      known_ids = frozenset(tracker.processed_spans)
      found = []
      for match in re.finditer(r'\[([\w]{7,8})\]', text):
          citation_id = match.group(1)
          if citation_id in known_ids:
              continue  # previously seen IDs are deliberately not reported
          span = (match.start() + start_offset, match.end() + start_offset)
          if tracker.is_new_span(citation_id, span):
              found.append((citation_id, *span))
      return found