# tests/conftest.py
import os

import pytest

from core.base import AppConfig, DatabaseConfig, VectorQuantizationType
from core.providers import NaClCryptoConfig, NaClCryptoProvider
from core.providers.database.postgres import (
    PostgresChunksHandler,
    PostgresCollectionsHandler,
    PostgresConversationsHandler,
    PostgresDatabaseProvider,
    PostgresDocumentsHandler,
    PostgresGraphsHandler,
    PostgresLimitsHandler,
    PostgresPromptsHandler,
)
from core.providers.database.users import PostgresUserHandler

TEST_DB_CONNECTION_STRING = os.environ.get(
    "TEST_DB_CONNECTION_STRING",
    "postgresql://postgres:postgres@localhost:5432/test_db",
)
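
# The suite defaults to a local Postgres instance; point it elsewhere by
# exporting the variable before running pytest, e.g.:
#   export TEST_DB_CONNECTION_STRING="postgresql://user:pass@db-host:5432/test_db"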


@pytest.fixture
async def db_provider():
    crypto_provider = NaClCryptoProvider(NaClCryptoConfig(app={}))
    db_config = DatabaseConfig(
        app=AppConfig(project_name="test_project"),
        provider="postgres",
        connection_string=TEST_DB_CONNECTION_STRING,
        postgres_configuration_settings={
            "max_connections": 10,
            "statement_cache_size": 100,
        },
        project_name="test_project",
    )
    dimension = 4
    quantization_type = VectorQuantizationType.FP32
    db_provider = PostgresDatabaseProvider(
        db_config, dimension, crypto_provider, quantization_type
    )
    await db_provider.initialize()
    yield db_provider
    # Teardown: close the provider once the dependent test finishes
    await db_provider.close()
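

# Usage sketch (illustrative, not part of the original file): tests consume
# these async fixtures via pytest-asyncio, e.g. in a test module:
#
#     @pytest.mark.asyncio
#     async def test_provider_initializes(db_provider):
#         assert db_provider.project_name == "test_project"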


@pytest.fixture
def crypto_provider():
    # Provide a crypto provider fixture if needed separately
    return NaClCryptoProvider(NaClCryptoConfig(app={}))


@pytest.fixture
async def chunks_handler(db_provider):
    dimension = db_provider.dimension
    quantization_type = db_provider.quantization_type
    project_name = db_provider.project_name
    connection_manager = db_provider.connection_manager

    handler = PostgresChunksHandler(
        project_name=project_name,
        connection_manager=connection_manager,
        dimension=dimension,
        quantization_type=quantization_type,
    )
    await handler.create_tables()
    return handler


@pytest.fixture
async def collections_handler(db_provider):
    project_name = db_provider.project_name
    connection_manager = db_provider.connection_manager
    config = db_provider.config

    handler = PostgresCollectionsHandler(
        project_name=project_name,
        connection_manager=connection_manager,
        config=config,
    )
    await handler.create_tables()
    return handler


@pytest.fixture
async def conversations_handler(db_provider):
    project_name = db_provider.project_name
    connection_manager = db_provider.connection_manager

    handler = PostgresConversationsHandler(project_name, connection_manager)
    await handler.create_tables()
    return handler


@pytest.fixture
async def documents_handler(db_provider):
    dimension = db_provider.dimension
    project_name = db_provider.project_name
    connection_manager = db_provider.connection_manager

    handler = PostgresDocumentsHandler(
        project_name=project_name,
        connection_manager=connection_manager,
        dimension=dimension,
    )
    await handler.create_tables()
    return handler


@pytest.fixture
async def limits_handler(db_provider):
    project_name = db_provider.project_name
    connection_manager = db_provider.connection_manager
    config = db_provider.config

    handler = PostgresLimitsHandler(
        project_name=project_name,
        connection_manager=connection_manager,
        config=config,
    )
    await handler.create_tables()
    # Start each test from an empty request log
    await connection_manager.execute_query(
        f"TRUNCATE {handler._get_table_name('request_log')};"
    )
    return handler


@pytest.fixture
async def users_handler(db_provider, crypto_provider):
    project_name = db_provider.project_name
    connection_manager = db_provider.connection_manager

    handler = PostgresUserHandler(
        project_name=project_name,
        connection_manager=connection_manager,
        crypto_provider=crypto_provider,
    )
    await handler.create_tables()
    # Clean up the user tables so each test starts from a known state
    await connection_manager.execute_query(
        f"TRUNCATE {handler._get_table_name('users')} CASCADE;"
    )
    await connection_manager.execute_query(
        f"TRUNCATE {handler._get_table_name('users_api_keys')} CASCADE;"
    )
    return handler


@pytest.fixture
async def prompt_handler(db_provider):
    """Returns an instance of PostgresPromptsHandler, creating the necessary
    tables first."""
    project_name = db_provider.project_name
    connection_manager = db_provider.connection_manager

    handler = PostgresPromptsHandler(
        project_name=project_name,
        connection_manager=connection_manager,
        # A local prompt directory can be supplied here if desired
        prompt_directory=None,
    )
    # Create the necessary tables and perform the initial prompt load
    await handler.create_tables()
    return handler


@pytest.fixture
async def graphs_handler(db_provider):
    project_name = db_provider.project_name
    connection_manager = db_provider.connection_manager
    dimension = db_provider.dimension
    quantization_type = db_provider.quantization_type

    # Ensure the 'collection_ids' column exists on the entities table
    create_col_sql = f"""
    ALTER TABLE "{project_name}"."graphs_entities"
    ADD COLUMN IF NOT EXISTS collection_ids UUID[] DEFAULT '{{}}';
    """
    await connection_manager.execute_query(create_col_sql)

    handler = PostgresGraphsHandler(
        project_name=project_name,
        connection_manager=connection_manager,
        dimension=dimension,
        quantization_type=quantization_type,
        # Pass the collections_handler fixture here if a test requires it
        collections_handler=None,
    )
    await handler.create_tables()
    return handler


# Citation testing fixtures and utilities
import json
import re
from typing import Any, AsyncGenerator, Tuple
from unittest.mock import AsyncMock, MagicMock

from core.agent.base import R2RStreamingAgent
from core.base import (
    GenerationConfig,
    LLMChatCompletion,
    LLMChatCompletionChunk,
    Message,
)
from core.utils import (
    CitationTracker,
    SearchResultsCollector,
    extract_citation_spans,
    find_new_citation_spans,
)


class MockLLMProvider:
    """Mock LLM provider for testing."""

    def __init__(self, response_content=None, citations=None):
        self.response_content = response_content or "This is a response"
        self.citations = citations or []

    async def aget_completion(self, messages, generation_config):
        """Mock non-streaming completion."""
        content = self.response_content
        for citation in self.citations:
            content += f" [{citation}]"

        mock_response = MagicMock(spec=LLMChatCompletion)
        mock_response.choices = [MagicMock()]
        mock_response.choices[0].message = MagicMock()
        mock_response.choices[0].message.content = content
        mock_response.choices[0].finish_reason = "stop"
        return mock_response

    async def aget_completion_stream(self, messages, generation_config):
        """Mock streaming completion."""
        content = self.response_content
        for citation in self.citations:
            content += f" [{citation}]"

        # Simulate streaming by yielding one character at a time
        for char in content:
            chunk = MagicMock(spec=LLMChatCompletionChunk)
            chunk.choices = [MagicMock()]
            chunk.choices[0].delta = MagicMock()
            chunk.choices[0].delta.content = char
            chunk.choices[0].finish_reason = None
            yield chunk

        # Final chunk with finish_reason="stop"
        final_chunk = MagicMock(spec=LLMChatCompletionChunk)
        final_chunk.choices = [MagicMock()]
        final_chunk.choices[0].delta = MagicMock()
        final_chunk.choices[0].delta.content = ""
        final_chunk.choices[0].finish_reason = "stop"
        yield final_chunk


class MockPromptsHandler:
    """Mock prompts handler for testing."""

    async def get_cached_prompt(self, prompt_key, inputs=None, *args, **kwargs):
        """Return a mock system prompt."""
        return "You are a helpful assistant that provides well-sourced information."


class MockDatabaseProvider:
    """Mock database provider for testing."""

    def __init__(self):
        # Add a prompts_handler attribute to prevent AttributeError
        self.prompts_handler = MockPromptsHandler()

    async def acreate_conversation(self, *args, **kwargs):
        return {"id": "conv_12345"}

    async def aupdate_conversation(self, *args, **kwargs):
        return True

    async def acreate_message(self, *args, **kwargs):
        return {"id": "msg_12345"}


class MockSearchResultsCollector:
    """Mock search results collector for testing."""

    def __init__(self, results=None):
        self.results = results or {}

    def find_by_short_id(self, short_id):
        # Fall back to a synthesized result when the short ID is unknown
        return self.results.get(
            short_id,
            {
                "document_id": f"doc_{short_id}",
                "text": f"This is document text for {short_id}",
                "metadata": {"source": f"source_{short_id}"},
            },
        )


# Concrete implementation of R2RStreamingAgent for testing
class MockR2RStreamingAgent(R2RStreamingAgent):
    """Mock streaming agent for testing that implements the abstract method."""

    # Regex patterns for citations, copied from the actual agent
    BRACKET_PATTERN = re.compile(r"\[([^\]]+)\]")
    SHORT_ID_PATTERN = re.compile(r"[A-Za-z0-9]{7,8}")

    def _register_tools(self):
        """Implement the abstract method with a no-op version."""
        pass

    async def _setup(self, system_instruction=None, *args, **kwargs):
        """Override _setup to simplify initialization and avoid external
        dependencies."""
        # Use a simple system message instead of fetching one from the database
        system_content = (
            system_instruction
            or "You are a helpful assistant that provides well-sourced information."
        )
        # Add the system message to the conversation
        await self.conversation.add_message(
            Message(role="system", content=system_content)
        )

    def _format_sse_event(self, event_type, data):
        """Format a server-sent event (SSE) manually."""
        return f"event: {event_type}\ndata: {json.dumps(data)}\n\n"

    async def arun(
        self,
        system_instruction: str = None,
        messages: list[Message] = None,
        *args,
        **kwargs,
    ) -> AsyncGenerator[str, None]:
        """Simplified version of arun that focuses on citation handling for
        testing."""
        await self._setup(system_instruction)

        if messages:
            for m in messages:
                await self.conversation.add_message(m)

        # Initialize citation tracking
        citation_tracker = CitationTracker()
        citation_payloads = {}

        # Track streaming citations for final persistence
        self.streaming_citations = []

        # Canned LLM response with two short-ID citations
        response_content = "This is a test response with citations"
        response_content += " [abc1234] [def5678]"

        # Yield an initial message event with the full text
        yield self._format_sse_event("message", {"content": response_content})

        # Extract citation spans in one pass; this is simpler than the
        # character-by-character scan the real agent performs
        citation_spans = extract_citation_spans(response_content)

        # Process the citations
        for cid, spans in citation_spans.items():
            for span in spans:
                # Only emit an event for spans that have not been seen yet
                if citation_tracker.is_new_span(cid, span):
                    # Look up the source document for this citation
                    source_doc = self.search_results_collector.find_by_short_id(cid)

                    # Create the citation payload
                    citation_payload = {
                        "document_id": source_doc.get("document_id", f"doc_{cid}"),
                        "text": source_doc.get("text", f"This is document text for {cid}"),
                        "metadata": source_doc.get("metadata", {"source": f"source_{cid}"}),
                    }

                    # Store the payload by citation ID
                    citation_payloads[cid] = citation_payload

                    # Track for persistence
                    self.streaming_citations.append({
                        "id": cid,
                        "span": {"start": span[0], "end": span[1]},
                        "payload": citation_payload,
                    })

                    # Emit the citation event in the expected format
                    citation_event = {
                        "id": cid,
                        "object": "citation",
                        "span": {"start": span[0], "end": span[1]},
                        "payload": citation_payload,
                    }
                    yield self._format_sse_event("citation", citation_event)

        # Add the assistant message, with citation metadata, to the conversation
        await self.conversation.add_message(
            Message(
                role="assistant",
                content=response_content,
                metadata={"citations": self.streaming_citations},
            )
        )

        # Consolidate citations for the final answer, grouping all spans by ID
        consolidated_citations = []
        for cid, spans in citation_tracker.get_all_spans().items():
            if cid in citation_payloads:
                consolidated_citations.append({
                    "id": cid,
                    "object": "citation",
                    "spans": [{"start": s[0], "end": s[1]} for s in spans],
                    "payload": citation_payloads[cid],
                })

        # Create and emit the final answer event
        final_evt_payload = {
            "id": "msg_final",
            "object": "agent.final_answer",
            "generated_answer": response_content,
            "citations": consolidated_citations,
        }
        yield self._format_sse_event("agent.final_answer", final_evt_payload)

        # Signal the end of the SSE stream
        yield "event: done\ndata: {}\n\n"


@pytest.fixture
def mock_streaming_agent():
    """Create a streaming agent with mocked dependencies."""
    # Create a mock config
    config = MagicMock()
    config.stream = True
    config.max_iterations = 3

    # Create mock providers
    llm_provider = MockLLMProvider(
        response_content="This is a test response with citations",
        citations=["abc1234", "def5678"],
    )
    db_provider = MockDatabaseProvider()

    # Create the agent with mocked dependencies, using the concrete
    # implementation above
    agent = MockR2RStreamingAgent(
        database_provider=db_provider,
        llm_provider=llm_provider,
        config=config,
        rag_generation_config=GenerationConfig(model="test/model"),
    )

    # Replace the search results collector with the mock
    agent.search_results_collector = MockSearchResultsCollector({
        "abc1234": {
            "document_id": "doc_abc1234",
            "text": "This is document text for abc1234",
            "metadata": {"source": "source_abc1234"},
        },
        "def5678": {
            "document_id": "doc_def5678",
            "text": "This is document text for def5678",
            "metadata": {"source": "source_def5678"},
        },
    })
    return agent


async def collect_stream_output(stream):
    """Collect all output from a stream into a list."""
    output = []
    async for event in stream:
        output.append(event)
    return output
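

# Usage sketch (an illustration, not part of this conftest): a test module
# could drive the mock agent end-to-end and assert on the emitted SSE frames;
# pytest.mark.asyncio from pytest-asyncio is assumed to be available.
#
#     @pytest.mark.asyncio
#     async def test_streaming_citations(mock_streaming_agent):
#         output = await collect_stream_output(mock_streaming_agent.arun())
#         assert any(frame.startswith("event: citation") for frame in output)
#         assert output[-1] == "event: done\ndata: {}\n\n"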
|