12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
370470570670770870971071171271371471571671771871972072172272372472572672772872973073173273373473573673773873974074174274374474574674774874975075175275375475575675775875976076176276376476576676776876977077177277377477577677777877978078178278378478578678778878979079179279379479579679779879980080180280380480580680780880981081181281381481581681781881982082182282382482582682782882983083183283383483583683783883984084184284384484584684784884985085185285385485585685785885986086186286386486586686786886987087187287387487587687787887988088188288388488588688788888989089189289389489589689789889990090190290390490590690790890991091191291391491591691791891992092192292392492592692792892993093193293393493593693793893994094194294394494594694794894995095195295395495595695795895996096196296396496596696796896997097197297397497597697797897998098198298398498598698798898999099199299399499599699799899910001001100210031004100510061007100810091010 |
- """
- Unit tests for citation handling in retrieval functionality.
- """
- import pytest
- import re
- from unittest.mock import AsyncMock, MagicMock, patch
- from typing import Dict, List, Any, Optional
- # Import citation utilities from core.utils
- from core.utils import (
- extract_citations,
- extract_citation_spans,
- find_new_citation_spans,
- CitationTracker as CoreCitationTracker
- )
class CitationTracker:
    """Minimal span-level citation tracker used by the tests in this module.

    For every citation ID it remembers which spans have already been
    reported, so that a repeated sighting of the same span can be ignored.
    Both stores are plain public dicts so tests can inspect them directly.
    """

    def __init__(self):
        # citation_id -> set of spans; used for the duplicate check.
        self.processed_spans = {}
        # citation_id -> list of spans in the order they were first seen.
        self.citation_spans = {}

    def is_new_span(self, citation_id, span):
        """Record ``span`` for ``citation_id`` and report whether it was new.

        Args:
            citation_id: Identifier of the citation. ``None`` and ``""`` are
                rejected.
            span: Hashable span, normally a ``(start, end)`` tuple. ``None``
                is rejected. The values are not validated, so negative or
                reversed positions are stored as given.

        Returns:
            ``True`` the first time a span is seen for the citation,
            ``False`` for repeats and for rejected input.
        """
        if citation_id is None or citation_id == "" or span is None:
            return False

        seen = self.processed_spans.setdefault(citation_id, set())
        if span in seen:
            return False

        seen.add(span)
        self.citation_spans.setdefault(citation_id, []).append(span)
        return True

    def get_all_citation_spans(self):
        """Return a mapping of citation ID to the list of spans seen so far.

        The lists are copies, so callers may mutate the result without
        altering the tracker's own records.
        """
        return {
            citation_id: list(spans)
            for citation_id, spans in self.citation_spans.items()
        }
class MockCitation:
    """Lightweight stand-in for a citation record.

    Any optional field that is omitted (or falsy) is filled with a value
    derived from ``citation_id``; ``spans`` always starts as an empty list.
    """

    def __init__(self, citation_id, chunk_id=None, document_id=None, text=None, metadata=None):
        self.citation_id = citation_id
        self.spans = []

        # Fall back to ID-derived placeholders for anything not supplied.
        self.chunk_id = chunk_id if chunk_id else f"chunk-{citation_id}"
        self.document_id = document_id if document_id else f"doc-{citation_id}"
        self.text = text if text else f"Citation text for {citation_id}"
        self.metadata = metadata if metadata else {"source": f"source-{citation_id}"}
@pytest.fixture
def mock_providers():
    """Return a mocked providers object exposing ``database`` and ``llm``.

    * ``database.citations_handler.get_citation(citation_id)`` resolves to a
      ``MockCitation`` built from the requested ID.
    * ``llm.aget_completion(...)`` resolves to a canned completion whose
      content is ``"Response with [abc1234] citation"``.
    * ``llm.aget_completion_stream(...)`` resolves to an iterator over the
      same content split into four delta chunks.
    """
    # Delta chunks of the canned streamed completion.
    stream_chunks = [
        {"choices": [{"delta": {"content": "Response "}}]},
        {"choices": [{"delta": {"content": "with "}}]},
        {"choices": [{"delta": {"content": "[abc1234] "}}]},
        {"choices": [{"delta": {"content": "citation"}}]},
    ]

    class MockProviders:
        def __init__(self):
            # Mock the database
            self.database = AsyncMock()
            self.database.citations_handler = AsyncMock()
            self.database.citations_handler.get_citation = AsyncMock(
                side_effect=lambda citation_id: MockCitation(citation_id)
            )

            # Mock LLM
            self.llm = AsyncMock()
            self.llm.aget_completion = AsyncMock(
                return_value={"choices": [{"message": {"content": "Response with [abc1234] citation"}}]}
            )
            # Build a fresh iterator on every call: a single iter(...) used
            # as return_value would be shared by all calls and come back
            # exhausted after the first consumer drained it.
            self.llm.aget_completion_stream = AsyncMock(
                side_effect=lambda *args, **kwargs: iter(stream_chunks)
            )

    return MockProviders()
@pytest.fixture
def sample_chunk_results():
    """Return five sample chunk results carrying citation metadata.

    Chunk ``i`` belongs to document ``doc-{i // 2}``, is tagged with
    citation ID ``cite{i}``, and its score drops by 0.05 per chunk starting
    from 0.95.
    """
    results = []
    for idx in range(5):
        chunk_metadata = {
            "source": f"source-{idx}",
            "citation_id": f"cite{idx}",
        }
        results.append({
            "chunk_id": f"chunk-{idx}",
            "document_id": f"doc-{idx//2}",
            "text": f"This is chunk {idx} with information about the topic.",
            "metadata": chunk_metadata,
            "score": 0.95 - (idx * 0.05),
        })
    return results
class TestCitationExtraction:
    """Tests for regex-based citation extraction.

    Each test defines its own small extraction helper and checks it against
    hand-written examples; none of them call the helpers imported from
    ``core.utils``.
    """

    def test_extract_citations_basic(self):
        """Bracketed word-character IDs are returned in order of appearance."""
        def extract_citations(text):
            return re.findall(r'\[([\w\d]+)\]', text)

        scenarios = [
            (
                "Aristotle discussed virtue ethics in his Nicomachean Ethics [abc1234].",
                ["abc1234"],
            ),
            (
                "According to Plato [xyz5678] and Aristotle [abc1234], philosophy is important.",
                ["xyz5678", "abc1234"],
            ),
            (
                "This text has no citations.",
                [],
            ),
            (
                "Multiple citations in a row [abc1234][def5678][ghi9012] should all be found.",
                ["abc1234", "def5678", "ghi9012"],
            ),
        ]

        for passage, wanted in scenarios:
            assert extract_citations(passage) == wanted

    def test_extract_citations_with_spans(self):
        """Each match reports its ID, its exact span and a context window."""
        def extract_citations_with_spans(text):
            found = []
            for hit in re.finditer(r'\[([\w\d]+)\]', text):
                begin, finish = hit.span()
                # Up to 50 characters on either side of the marker.
                window = text[max(0, begin - 50):min(len(text), finish + 50)]
                found.append({
                    "citation_id": hit.group(1),
                    "start": begin,
                    "end": finish,
                    "context": window
                })
            return found

        text = (
            "Aristotle discussed virtue ethics in his Nicomachean Ethics [abc1234]. "
            "According to Plato [xyz5678], the ideal state is described in The Republic. "
            "Socrates' method of questioning is demonstrated in many dialogues [ghi9012]."
        )

        extracted = extract_citations_with_spans(text)

        # Three markers, in document order.
        assert len(extracted) == 3
        for position, wanted_id in enumerate(["abc1234", "xyz5678", "ghi9012"]):
            assert extracted[position]["citation_id"] == wanted_id

        # Every span must slice back to the literal marker, and the context
        # window must contain the ID.
        for entry in extracted:
            begin, finish = entry["start"], entry["end"]
            assert begin < finish
            assert text[begin:finish] == f"[{entry['citation_id']}]"
            assert entry["citation_id"] in entry["context"]

    def test_citation_extraction_edge_cases(self):
        """Malformed or wrongly sized bracket contents are not extracted.

        The local helper accepts only 7-8 alphanumeric characters between
        brackets and returns ``[]`` for ``None`` or empty input.
        """
        def extract_citations(text):
            if text is None or text == "":
                return []
            marker = re.compile(r"\[([A-Za-z0-9]{7,8})\]")
            return [hit.group(1) for hit in marker.finditer(text)]

        scenarios = [
            # No closing bracket, so the pattern cannot match.
            ("Incomplete citation [abc1234", []),
            # Nothing between the brackets; the 7-8 character rule rejects it.
            ("Empty citation []", []),
            # A hyphen is outside the allowed character class.
            ("Citation with special chars [abc-1234]", []),
            # No whitespace is required before the opening bracket.
            ("Citation at the end of sentence[abcd1234].", ["abcd1234"]),
            # Plain valid citation.
            ("Valid citation [abc1234]", ["abc1234"]),
            # 'short' has only 5 characters.
            ("Text with [short] but no valid citation format.", []),
            # 6 characters is too few, 9 is too many.
            ("Text with [abc123] (too short) and [abcdefghi] (too long).", []),
            # 8 characters long, but one of them is a hyphen.
            ("Text with [abc-1234] has the right length but contains special characters.", []),
        ]

        for passage, wanted in scenarios:
            assert extract_citations(passage) == wanted

    def test_citation_sanitization(self):
        """Every non-alphanumeric character is stripped from a citation ID."""
        disallowed = re.compile(r'[^a-zA-Z0-9]')

        def sanitize_citation_id(citation_id):
            return disallowed.sub('', citation_id)

        scenarios = [
            ("abc1234", "abc1234"),   # already clean
            ("abc-1234", "abc1234"),  # hyphen
            ("abc.1234", "abc1234"),  # period
            ("abc_1234", "abc1234"),  # underscore
            ("abc 1234", "abc1234"),  # space
        ]

        for raw_id, wanted in scenarios:
            assert sanitize_citation_id(raw_id) == wanted
class TestCitationTracker:
    """Tests for the ``CitationTracker`` class defined in this module."""

    def test_citation_tracker_init(self):
        """A new tracker exposes two empty dict stores."""
        tracker = CitationTracker()

        for attribute in ('processed_spans', 'citation_spans'):
            assert hasattr(tracker, attribute)
            store = getattr(tracker, attribute)
            assert isinstance(store, dict)
            assert len(store) == 0

    def test_is_new_span(self):
        """Newness is decided per (citation ID, span) pair."""
        tracker = CitationTracker()

        # The steps run in order against the same tracker.
        steps = [
            ("abc1234", (10, 18), True),   # first occurrence
            ("abc1234", (10, 18), False),  # exact repeat
            ("abc1234", (30, 38), True),   # same ID, different span
            ("def5678", (10, 18), True),   # same span, different ID
        ]
        for citation_id, span, expected in steps:
            assert tracker.is_new_span(citation_id, span) is expected

    def test_get_all_citation_spans(self):
        """Recorded spans come back grouped by citation ID."""
        tracker = CitationTracker()

        recorded = [
            ("abc1234", (10, 18)),
            ("abc1234", (30, 38)),
            ("def5678", (50, 58)),
        ]
        for citation_id, span in recorded:
            tracker.is_new_span(citation_id, span)

        all_spans = tracker.get_all_citation_spans()

        assert "abc1234" in all_spans
        assert "def5678" in all_spans
        assert len(all_spans["abc1234"]) == 2
        assert len(all_spans["def5678"]) == 1
        for citation_id, span in recorded:
            assert span in all_spans[citation_id]

    def test_citation_tracker_multiple_spans(self):
        """Spans found by scanning a text are grouped under their IDs."""
        tracker = CitationTracker()

        # "[abc1234]" occurs twice at different offsets, "[def5678]" once.
        text = (
            "Aristotle discussed virtue ethics in his Nicomachean Ethics [abc1234]. "
            "Later in the same work [abc1234], he expanded on this concept. "
            "According to Plato [def5678], the ideal state is described in The Republic."
        )

        for hit in re.finditer(r'\[([\w\d]+)\]', text):
            tracker.is_new_span(hit.group(1), hit.span())

        all_spans = tracker.get_all_citation_spans()
        assert len(all_spans["abc1234"]) == 2
        assert len(all_spans["def5678"]) == 1
class TestCitationStreamingEvents:
    """Tests for citation events emitted while text is streamed in chunks."""

    def test_emit_citation_event(self):
        """A citation event carries the ID, span and context it was given."""
        class MockAgent:
            """Agent stub that just collects emitted events."""

            def __init__(self):
                self.emitted_events = []

            def emit_event(self, event):
                self.emitted_events.append(event)

        def emit_citation_event(agent, citation_id, start, end, text_context):
            payload = {
                "citation_id": citation_id,
                "start": start,
                "end": end,
                "text_context": text_context
            }
            agent.emit_event({"type": "citation", "data": payload})

        agent = MockAgent()
        emit_citation_event(agent, "abc1234", 10, 18, "text with [abc1234] citation")

        assert len(agent.emitted_events) == 1
        emitted = agent.emitted_events[0]
        assert emitted["type"] == "citation"
        details = emitted["data"]
        assert details["citation_id"] == "abc1234"
        assert details["start"] == 10
        assert details["end"] == 18

    def test_citation_tracking_during_streaming(self):
        """Each chunk is scanned alone; spans are offset into the full text."""
        class MockAgent:
            """Agent stub with an event log and its own citation tracker."""

            def __init__(self):
                self.emitted_events = []
                self.citation_tracker = CitationTracker()

            def emit_event(self, event):
                self.emitted_events.append(event)

        def process_streaming_text(agent, text, start_offset=0):
            for hit in re.finditer(r'\[([\w\d]+)\]', text):
                citation_id = hit.group(1)
                # Positions relative to the whole stream, not just this chunk.
                start = hit.start() + start_offset
                end = hit.end() + start_offset

                if not agent.citation_tracker.is_new_span(citation_id, (start, end)):
                    continue

                # The context window is cut from the current chunk only.
                window = text[max(0, hit.start() - 10):min(len(text), hit.end() + 10)]
                agent.emit_event({
                    "type": "citation",
                    "data": {
                        "citation_id": citation_id,
                        "start": start,
                        "end": end,
                        "text_context": window
                    }
                })

        agent = MockAgent()
        chunks = [
            "Aristotle discussed virtue ethics ",
            "in his Nicomachean Ethics [abc1234]. ",
            "According to Plato [def5678], ",
            "the ideal state is described in The Republic. ",
            "Later, Aristotle also mentioned [abc1234] this concept."
        ]

        consumed = 0
        for piece in chunks:
            process_streaming_text(agent, piece, consumed)
            consumed += len(piece)

        # Three events in total: two for abc1234 (different offsets), one for def5678.
        assert len(agent.emitted_events) == 3

        emitted_ids = [event["data"]["citation_id"] for event in agent.emitted_events]
        assert emitted_ids.count("abc1234") == 2
        assert emitted_ids.count("def5678") == 1

        # The tracker must agree with the event log.
        all_spans = agent.citation_tracker.get_all_citation_spans()
        assert len(all_spans["abc1234"]) == 2
        assert len(all_spans["def5678"]) == 1
class TestRAGWithCitations:
    """Tests for building RAG prompts and answers that carry citations."""

    @pytest.mark.asyncio
    async def test_rag_with_citation_metadata(self, mock_providers, sample_chunk_results):
        """The prompt embeds every citation marker and metadata is kept per ID."""
        def build_rag_prompt_with_citations(query, search_results):
            context_lines = []
            citation_metadata = {}
            for number, result in enumerate(search_results, start=1):
                citation_id = result.get("metadata", {}).get("citation_id")
                if not citation_id:
                    # No citation ID: numbered context line without a marker.
                    context_lines.append(f"\n[{number}] {result['text']}")
                    continue

                context_lines.append(f"\n[{number}] {result['text']} [{citation_id}]")
                citation_metadata[citation_id] = {
                    "document_id": result["document_id"],
                    "chunk_id": result["chunk_id"],
                    "metadata": result.get("metadata", {})
                }

            context = "".join(context_lines)
            prompt = f"Question: {query}\n\nContext:{context}\n\nPlease answer the question based on the provided context."
            return prompt, citation_metadata

        query = "What is the main concept?"
        prompt, citation_metadata = build_rag_prompt_with_citations(query, sample_chunk_results)

        # Every sample chunk's marker appears in the prompt.
        for i in range(5):
            assert f"[cite{i}]" in prompt

        # One metadata entry per citation, pointing at the right chunk/document.
        assert len(citation_metadata) == 5
        for i in range(5):
            key = f"cite{i}"
            assert key in citation_metadata
            assert citation_metadata[key]["document_id"] == f"doc-{i//2}"
            assert citation_metadata[key]["chunk_id"] == f"chunk-{i}"

    @pytest.mark.asyncio
    async def test_rag_response_with_citations(self, mock_providers, sample_chunk_results):
        """Citation markers in the mocked LLM answer can be extracted again."""
        async def generate_rag_response_with_citations(query, search_results):
            context_lines = []
            citation_metadata = {}
            for number, result in enumerate(search_results, start=1):
                citation_id = result.get("metadata", {}).get("citation_id")
                if not citation_id:
                    context_lines.append(f"\n[{number}] {result['text']}")
                    continue

                context_lines.append(f"\n[{number}] {result['text']} [{citation_id}]")
                citation_metadata[citation_id] = {
                    "document_id": result["document_id"],
                    "chunk_id": result["chunk_id"],
                    "metadata": result.get("metadata", {})
                }

            context = "".join(context_lines)
            prompt = f"Question: {query}\n\nContext:{context}\n\nPlease answer the question based on the provided context."

            # The LLM is mocked: override the canned completion for this test.
            canned_answer = "The main concept is explained in [cite0] and further elaborated in [cite2]."
            mock_providers.llm.aget_completion.return_value = {
                "choices": [{
                    "message": {
                        "content": canned_answer
                    }
                }]
            }

            completion = await mock_providers.llm.aget_completion(prompt=prompt)
            return completion["choices"][0]["message"]["content"], citation_metadata

        query = "What is the main concept?"
        response, citation_metadata = await generate_rag_response_with_citations(query, sample_chunk_results)

        # The raw answer text contains both markers.
        assert "[cite0]" in response
        assert "[cite2]" in response

        def extract_citations_from_response(text):
            return re.findall(r'\[([\w\d]+)\]', text)

        citations = extract_citations_from_response(response)
        assert "cite0" in citations
        assert "cite2" in citations

    @pytest.mark.asyncio
    async def test_consolidate_citations_in_final_answer(self, mock_providers):
        """Tracked spans are merged with per-citation metadata in the result."""
        tracker = CitationTracker()
        for citation_id, span in (("cite0", (10, 18)), ("cite0", (30, 38)), ("cite2", (50, 58))):
            tracker.is_new_span(citation_id, span)

        citation_metadata = {
            "cite0": {
                "document_id": "doc-0",
                "chunk_id": "chunk-0",
                "metadata": {"source": "source-0", "title": "Document 0"}
            },
            "cite2": {
                "document_id": "doc-1",
                "chunk_id": "chunk-2",
                "metadata": {"source": "source-2", "title": "Document 1"}
            }
        }

        def consolidate_citations(response_text, citation_tracker, citation_metadata):
            # Only citations that have metadata make it into the result.
            consolidated_citations = {
                citation_id: {
                    "spans": spans,
                    "document_id": citation_metadata[citation_id]["document_id"],
                    "chunk_id": citation_metadata[citation_id]["chunk_id"],
                    "metadata": citation_metadata[citation_id]["metadata"]
                }
                for citation_id, spans in citation_tracker.get_all_citation_spans().items()
                if citation_id in citation_metadata
            }
            return {
                "response": response_text,
                "citations": consolidated_citations
            }

        response_text = "The main concept is explained in [cite0] and further elaborated in [cite2]."
        result = consolidate_citations(response_text, tracker, citation_metadata)

        # Shape of the result.
        assert "response" in result
        assert "citations" in result
        assert result["response"] == response_text

        # Content of the consolidated citations.
        merged = result["citations"]
        assert "cite0" in merged
        assert "cite2" in merged
        assert len(merged["cite0"]["spans"]) == 2
        assert len(merged["cite2"]["spans"]) == 1
        assert merged["cite0"]["document_id"] == "doc-0"
        assert merged["cite2"]["document_id"] == "doc-1"
- class TestCitationUtils:
- """Tests for citation utility functions."""
def test_extract_citations(self):
    """Check ``extract_citations`` on one, several and repeated citations."""
    # One citation.
    single = extract_citations("This is a test with a citation [abc1234].")
    assert single == ["abc1234"], "Should extract a single citation ID"

    # Two different citations.
    pair = extract_citations("First citation [abc1234] and second citation [def5678].")
    assert pair == ["abc1234", "def5678"], "Should extract multiple citation IDs"

    # The same citation twice: duplicates are expected to be kept, in order.
    repeated = extract_citations("Same citation twice [abc1234] and again [abc1234].")
    assert len(repeated) == 2, "Should extract duplicate citation IDs"
    assert repeated == ["abc1234", "abc1234"], "Should preserve order of citations"
- def test_extract_citations_edge_cases(self):
- """Test edge cases for citation extraction."""
- # Define local extract_citations for testing that follows the core implementation
- def local_extract_citations(text):
- # Handle None or empty input
- if text is None or text == "":
- return []
- # Match the core implementation pattern: 7-8 alphanumeric chars
- citation_pattern = re.compile(r"\[([A-Za-z0-9]{7,8})\]")
- sids = []
- for match in citation_pattern.finditer(text):
- sid = match.group(1)
- sids.append(sid)
- return sids
- # Citations at beginning or end of text
- text = "[abc1234] at the beginning and at the end [def5678]"
- citations = local_extract_citations(text)
- assert citations == ["abc1234", "def5678"], "Should extract citations at beginning and end"
- # Empty text
- text = ""
- citations = local_extract_citations(text)
- assert citations == [], "Should handle empty text gracefully"
- # None input
- citations = local_extract_citations(None)
- assert citations == [], "Should handle None input gracefully"
- # Text with brackets but no valid citation format
- text = "Text with [short] but no valid citation format."
- citations = local_extract_citations(text)
- assert citations == [], "Should not extract non-citation brackets (too short)"
- # Text with brackets but wrong length
- text = "Text with [abc123] (too short) and [abcdefghi] (too long)."
- citations = local_extract_citations(text)
- assert citations == [], "Should not extract brackets with wrong length"
- # Text with brackets that have correct length but non-alphanumeric chars
- text = "Text with [abc-1234] has the right length but contains special characters."
- citations = local_extract_citations(text)
- assert citations == [], "Should not extract brackets with special characters"
- # Text with close brackets only
- text = "Text with close brackets only]."
- citations = local_extract_citations(text)
- assert citations == [], "Should not extract when only close brackets present"
def test_extract_citation_spans(self):
    """Check that ``extract_citation_spans`` maps each ID to its positions."""
    # One citation: one key with one span that slices back to the marker.
    single_text = "This is a test with a citation [abc1234]."
    single = extract_citation_spans(single_text)
    assert len(single) == 1, "Should extract one citation ID"
    assert "abc1234" in single, "Citation ID should be a key in the dictionary"
    assert len(single["abc1234"]) == 1, "Should have one span for this citation"
    begin, finish = single["abc1234"][0]
    assert single_text[begin:finish] == "[abc1234]", "Span positions should be correct"

    # Two citations: two keys, one span each.
    pair_text = "First citation [abc1234] and second citation [def5678]."
    pair = extract_citation_spans(pair_text)
    assert len(pair) == 2, "Should extract two citation IDs"
    assert "abc1234" in pair, "First citation ID should be present"
    assert "def5678" in pair, "Second citation ID should be present"
    assert len(pair["abc1234"]) == 1, "Should have one span for first citation"
    assert len(pair["def5678"]) == 1, "Should have one span for second citation"

    first_begin, first_finish = pair["abc1234"][0]
    second_begin, second_finish = pair["def5678"][0]
    assert pair_text[first_begin:first_finish] == "[abc1234]", "First span positions should be correct"
    assert pair_text[second_begin:second_finish] == "[def5678]", "Second span positions should be correct"
def test_extract_citation_spans_edge_cases(self):
    """Check ``extract_citation_spans`` on boundary and empty inputs."""
    # Markers touching the very start and very end of the text.
    boundary_text = "[abc1234] at the beginning and at the end [def5678]"
    boundary = extract_citation_spans(boundary_text)
    assert len(boundary) == 2, "Should extract two spans"
    assert "abc1234" in boundary, "First citation ID should be present"
    assert "def5678" in boundary, "Second citation ID should be present"
    assert len(boundary["abc1234"]) == 1, "Should have one span for first citation"
    assert len(boundary["def5678"]) == 1, "Should have one span for second citation"

    first_begin, first_finish = boundary["abc1234"][0]
    second_begin, second_finish = boundary["def5678"][0]
    assert boundary_text[first_begin:first_finish] == "[abc1234]", "First span should start at beginning"
    assert boundary_text[second_begin:second_finish] == "[def5678]", "Second span should end at end"

    # Empty string.
    assert extract_citation_spans("") == {}, "Should return empty dictionary for empty text"

    # None.
    assert extract_citation_spans(None) == {}, "Should return empty dictionary for None input"

    # Two separate markers inside one sentence (they do not actually overlap).
    nearby = extract_citation_spans("Text with overlapping [abc1234] brackets [def5678].")
    assert len(nearby) == 2, "Should extract two spans correctly even with proximity"
    assert "abc1234" in nearby, "First citation ID should be present"
    assert "def5678" in nearby, "Second citation ID should be present"
    assert len(nearby["abc1234"]) == 1, "Should have one span for first citation"
    assert len(nearby["def5678"]) == 1, "Should have one span for second citation"
def test_core_citation_tracker(self):
    """Walk a tracker through a new, a duplicate and a second span for one ID.

    NOTE(review): this instantiates the ``CitationTracker`` name in scope in
    this module, while the core class is imported under the alias
    ``CoreCitationTracker`` -- confirm which one this test should cover.
    """
    tracker = CitationTracker()
    first_span = (10, 20)
    second_span = (30, 40)

    # Nothing recorded yet.
    assert len(tracker.processed_spans) == 0, "Should start with empty citation spans"

    # First sighting is new and gets recorded.
    assert tracker.is_new_span("abc1234", first_span), "First span should be considered new"
    assert "abc1234" in tracker.processed_spans, "Citation ID should be in processed_spans"
    assert first_span in tracker.processed_spans["abc1234"], "Span should be recorded"

    # Repeating it changes nothing.
    assert not tracker.is_new_span("abc1234", first_span), "Duplicate span should not be considered new"
    assert len(tracker.processed_spans["abc1234"]) == 1, "Duplicate span should not be added again"

    # A different span under the same ID is new again.
    assert tracker.is_new_span("abc1234", second_span), "Different span for same citation should be new"
    assert len(tracker.processed_spans["abc1234"]) == 2, "New span should be added"
    assert second_span in tracker.processed_spans["abc1234"], "New span should be recorded"

    # get_all_citation_spans reports both spans under the ID.
    all_spans = tracker.get_all_citation_spans()
    assert "abc1234" in all_spans, "Citation ID should be in all spans"
    assert len(all_spans["abc1234"]) == 2, "Should have 2 spans for the citation"
def test_core_citation_tracker_edge_cases(self):
    """Check how the tracker treats invalid IDs and unusual span values."""
    tracker = CitationTracker()
    anchor_span = (10, 20)

    # Empty and None IDs are rejected.
    assert not tracker.is_new_span("", anchor_span), "Empty citation ID should not be tracked"
    assert not tracker.is_new_span(None, anchor_span), "None citation ID should not be tracked"

    # Span values themselves are expected to be accepted without validation.
    assert tracker.is_new_span("abc1234", (-5, 20)), "Negative start position should be accepted"
    assert tracker.is_new_span("abc1234", (30, 20)), "End before start should be accepted (implementation dependent)"

    # Overlapping spans are distinct spans.
    assert tracker.is_new_span("def5678", (10, 30)), "First overlapping span should be new"
    assert tracker.is_new_span("def5678", (20, 40)), "Second overlapping span should be new"
    assert len(tracker.processed_spans["def5678"]) == 2, "Both overlapping spans should be recorded"

    # A very wide span.
    wide_span = (0, 10000)
    assert tracker.is_new_span("large", wide_span), "Very large span should be tracked"
    assert wide_span in tracker.processed_spans["large"], "Large span should be recorded correctly"

    # Three IDs were accepted above ("abc1234", "def5678", "large"); the
    # rejected empty/None IDs must not add to that count.
    all_spans = tracker.get_all_citation_spans()
    assert len(all_spans) >= 3, "Should have at least 3 different citation IDs"
def test_find_new_citation_spans(self):
    """Feed several texts through ``find_new_citation_spans`` with one tracker.

    NOTE(review): the last step expects ``[abc1234]`` to be skipped although
    it sits at a different offset than in the first text; that only holds if
    ``find_new_citation_spans`` treats it as already seen -- confirm against
    the implementation in ``core.utils``.
    """
    tracker = CitationTracker()

    # First text: its single citation is new and ends up in the tracker.
    first_text = "This is a text with citation [abc1234]."
    first_batch = find_new_citation_spans(first_text, tracker)
    assert len(first_batch) == 1, "Should find one new span"
    first_hit = first_batch[0]
    assert first_hit[0] == "abc1234", "Citation ID should match"
    citation_id, start, end = first_hit
    assert citation_id in tracker.processed_spans, "Citation ID should be tracked"
    assert (start, end) in tracker.processed_spans[citation_id], "Span should be tracked"

    # Same text again: nothing new.
    second_batch = find_new_citation_spans(first_text, tracker)
    assert second_batch == [], "Should not find duplicate spans"

    # A text with an unseen citation.
    third_batch = find_new_citation_spans("This is another text with a new citation [def5678].", tracker)
    assert len(third_batch) == 1, "Should find one new span"
    assert third_batch[0][0] == "def5678", "New citation ID should match"

    # A text mixing a known and an unseen citation.
    fourth_batch = find_new_citation_spans("Text with both [abc1234] and [ghi9012].", tracker)
    assert len(fourth_batch) == 1, "Should only find the new span"
    assert fourth_batch[0][0] == "ghi9012", "Only new citation ID should be found"
def test_find_new_citation_spans_edge_cases(self):
    """Check ``find_new_citation_spans`` on empty, plain and dense inputs."""
    tracker = CitationTracker()

    # Empty string.
    assert find_new_citation_spans("", tracker) == [], "Should return empty list for empty text"

    # Text with no brackets at all.
    plain = find_new_citation_spans("This text has no citations or brackets.", tracker)
    assert plain == [], "Should return empty list for text without citations"

    # None.
    assert find_new_citation_spans(None, tracker) == [], "Should handle None input gracefully and return empty list"

    # Three different citations in a single text.
    dense_text = "Text with multiple citations [abc1234] and [def5678] and [ghi9012]."
    dense = find_new_citation_spans(dense_text, tracker)
    assert len(dense) == 3, "Should find three new spans"

    found_ids = [entry[0] for entry in dense]
    assert "abc1234" in found_ids, "First citation should be found"
    assert "def5678" in found_ids, "Second citation should be found"
    assert "ghi9012" in found_ids, "Third citation should be found"
def test_performance_with_many_citations(self):
    """Exercise extraction and tracking on a text holding 100 citations.

    The text is processed three ways: plain ID extraction, span
    extraction, and tracker-based discovery. It is then replayed in ten
    fixed-size chunks (with ``start_offset``) to imitate streaming input.
    No timing is measured; the test only checks that the counts hold up
    at this size.
    """
    citation_count = 100
    chunk_count = 10

    # Build the text in one pass with join instead of repeated `+=`.
    citations = [f"cit{i:04d}" for i in range(citation_count)]
    body = "".join(
        f"Citation {number}: [{citation}]. "
        for number, citation in enumerate(citations, start=1)
    )
    text = f"Beginning of text. {body}End of text."

    # Extract all citations
    extracted = extract_citations(text)
    assert len(extracted) == citation_count, "Should extract all 100 citations"

    # Extract all spans
    spans = extract_citation_spans(text)
    assert len(spans) == citation_count, "Should extract all 100 spans"

    # Test find_new_citation_spans with a tracker
    tracker = CitationTracker()
    new_spans = find_new_citation_spans(text, tracker)
    assert len(new_spans) == citation_count, "Should find all 100 spans as new"

    # Test finding spans in chunks (simulating streaming)
    chunk_size = len(text) // chunk_count
    tracker2 = CitationTracker()
    total_new_spans = 0
    for index in range(chunk_count):
        start = index * chunk_size
        # The final chunk absorbs the remainder left by integer division.
        end = len(text) if index == chunk_count - 1 else start + chunk_size
        chunk = text[start:end]
        new_spans_in_chunk = find_new_citation_spans(chunk, tracker2, start_offset=start)
        total_new_spans += len(new_spans_in_chunk)

    # We might not get exactly 100 because citations could be split across
    # chunk boundaries, but the majority should still be found.
    assert total_new_spans > 50, "Should find majority of spans even in chunks"
def test_streaming_citation_handling(self):
    """Test citation handling with simulated streaming updates.

    Text arrives in five chunks. Each chunk is scanned on its own and the
    match positions are shifted by the length of everything received
    before it, so spans are expressed in coordinates of the accumulated
    text. None of the chunks splits a citation across a boundary, so that
    case is not exercised here.
    """
    tracker = CitationTracker()

    # Simulate a streaming scenario where text comes in chunks
    chunks = [
        "This is the first chunk ",
        "with no citations. This is the second chunk with a ",
        "citation [abc1234] and some more text. ",
        "This is the third chunk with another citation [def5678] ",
        "and the first citation again [abc1234] in a new position.",
    ]

    # Compiled once: the pattern does not change between chunks.
    citation_pattern = re.compile(r'\[([\w]{7,8})\]')

    chunk_start = 0  # Offset of the current chunk within the accumulated text.
    total_spans_found = 0
    for chunk in chunks:
        for match in citation_pattern.finditer(chunk):
            citation_id = match.group(1)
            start = match.start() + chunk_start
            end = match.end() + chunk_start
            # Check if this span is new for this citation ID
            if tracker.is_new_span(citation_id, (start, end)):
                total_spans_found += 1
        chunk_start += len(chunk)

    # Check final state
    assert "abc1234" in tracker.processed_spans, "First citation should be tracked"
    assert "def5678" in tracker.processed_spans, "Second citation should be tracked"
    assert len(tracker.processed_spans["abc1234"]) == 2, "First citation should have 2 spans"
    assert len(tracker.processed_spans["def5678"]) == 1, "Second citation should have 1 span"
    assert total_spans_found == 3, "Should have found 3 spans in total"
def test_malformed_citations(self):
    """Ensure broken citation markers are ignored by every extraction path.

    The sample holds one well-formed citation surrounded by variants that
    must not match: unbalanced brackets, IDs that are too short or too
    long, empty brackets and an ID containing a non-word character.
    """
    sample = """
    This text has citations with issues:
    - Missing end bracket [abc1234
    - Missing start bracket def5678]
    - Wrong format [abc123] (too short)
    - Wrong format [abcdefghi] (too long)
    - Valid citation [abc1234]
    - Empty brackets []
    - Non-alphanumeric [abc@123]
    """

    # ID extraction keeps only the valid citation.
    found_ids = extract_citations(sample)
    assert len(found_ids) == 1, "Should only extract the one valid citation"
    assert found_ids[0] == "abc1234", "Valid citation should be extracted"

    # Span extraction agrees.
    found_spans = extract_citation_spans(sample)
    assert len(found_spans) == 1, "Should only extract span for the valid citation"
    assert "abc1234" in found_spans, "Valid citation span should be extracted"

    # Tracker-based discovery sees and records just that one citation.
    tracker = CitationTracker()
    fresh_spans = find_new_citation_spans(sample, tracker)
    assert len(fresh_spans) == 1, "Should only find one new valid citation span"
    assert fresh_spans[0][0] == "abc1234", "Valid citation should be found"
    assert len(tracker.processed_spans) == 1, "Should only track the valid citation"
def find_new_citation_spans(text, tracker, start_offset=0):
    """Return spans of citations in *text* whose ID the tracker had not seen.

    A citation is a bracketed run of 7-8 word characters, e.g. ``[abc1234]``.

    Filtering happens in two stages:

    1. Any citation ID already present in ``tracker.processed_spans`` when
       this function is called is skipped entirely, even if it now appears
       at a position that was never recorded.
    2. Remaining matches are passed to ``tracker.is_new_span``; only spans
       it accepts are returned. The tests in this file rely on
       ``is_new_span`` also recording the span in ``processed_spans``.

    Because the set of known IDs is captured once up front, an ID that
    first appears in this call may be reported several times if it occurs
    at several positions in *text*.

    Args:
        text: String to scan. ``None`` or an empty value yields ``[]``.
        tracker: Object exposing a ``processed_spans`` mapping keyed by
            citation ID and an ``is_new_span(citation_id, (start, end))``
            method.
        start_offset: Amount added to every match position, so spans found
            in a chunk can be expressed in coordinates of a larger text.

    Returns:
        A list of ``(citation_id, start, end)`` tuples in order of
        appearance, with ``start``/``end`` already shifted by
        ``start_offset``.
    """
    if not text:
        return []

    # Snapshot taken before scanning, so IDs first seen during this call
    # are not filtered out by the ID check below.
    previously_seen_ids = set(tracker.processed_spans)

    new_spans = []
    for match in re.finditer(r'\[([\w]{7,8})\]', text):
        citation_id = match.group(1)
        if citation_id in previously_seen_ids:
            continue

        span = (match.start() + start_offset, match.end() + start_offset)
        if tracker.is_new_span(citation_id, span):
            new_spans.append((citation_id, *span))
    return new_spans
|