conftest.py 16 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495
  1. # tests/conftest.py
  2. import os
  3. import pytest
  4. from core.base import AppConfig, DatabaseConfig, VectorQuantizationType
  5. from core.providers import NaClCryptoConfig, NaClCryptoProvider
  6. from core.providers.database.postgres import (
  7. PostgresChunksHandler,
  8. PostgresCollectionsHandler,
  9. PostgresConversationsHandler,
  10. PostgresDatabaseProvider,
  11. PostgresDocumentsHandler,
  12. PostgresGraphsHandler,
  13. PostgresLimitsHandler,
  14. PostgresPromptsHandler,
  15. )
  16. from core.providers.database.users import ( # Make sure this import is correct
  17. PostgresUserHandler, )
  18. TEST_DB_CONNECTION_STRING = os.environ.get(
  19. "TEST_DB_CONNECTION_STRING",
  20. "postgresql://postgres:postgres@localhost:5432/test_db",
  21. )
  22. @pytest.fixture
  23. async def db_provider():
  24. crypto_provider = NaClCryptoProvider(NaClCryptoConfig(app={}))
  25. db_config = DatabaseConfig(
  26. app=AppConfig(project_name="test_project"),
  27. provider="postgres",
  28. connection_string=TEST_DB_CONNECTION_STRING,
  29. postgres_configuration_settings={
  30. "max_connections": 10,
  31. "statement_cache_size": 100,
  32. },
  33. project_name="test_project",
  34. )
  35. dimension = 4
  36. quantization_type = VectorQuantizationType.FP32
  37. db_provider = PostgresDatabaseProvider(db_config, dimension,
  38. crypto_provider, quantization_type)
  39. await db_provider.initialize()
  40. yield db_provider
  41. # Teardown logic if needed
  42. await db_provider.close()
  43. @pytest.fixture
  44. def crypto_provider():
  45. # Provide a crypto provider fixture if needed separately
  46. return NaClCryptoProvider(NaClCryptoConfig(app={}))
  47. @pytest.fixture
  48. async def chunks_handler(db_provider):
  49. dimension = db_provider.dimension
  50. quantization_type = db_provider.quantization_type
  51. project_name = db_provider.project_name
  52. connection_manager = db_provider.connection_manager
  53. handler = PostgresChunksHandler(
  54. project_name=project_name,
  55. connection_manager=connection_manager,
  56. dimension=dimension,
  57. quantization_type=quantization_type,
  58. )
  59. await handler.create_tables()
  60. return handler
  61. @pytest.fixture
  62. async def collections_handler(db_provider):
  63. project_name = db_provider.project_name
  64. connection_manager = db_provider.connection_manager
  65. config = db_provider.config
  66. handler = PostgresCollectionsHandler(
  67. project_name=project_name,
  68. connection_manager=connection_manager,
  69. config=config,
  70. )
  71. await handler.create_tables()
  72. return handler
  73. @pytest.fixture
  74. async def conversations_handler(db_provider):
  75. project_name = db_provider.project_name
  76. connection_manager = db_provider.connection_manager
  77. handler = PostgresConversationsHandler(project_name, connection_manager)
  78. await handler.create_tables()
  79. return handler
  80. @pytest.fixture
  81. async def documents_handler(db_provider):
  82. dimension = db_provider.dimension
  83. project_name = db_provider.project_name
  84. connection_manager = db_provider.connection_manager
  85. handler = PostgresDocumentsHandler(
  86. project_name=project_name,
  87. connection_manager=connection_manager,
  88. dimension=dimension,
  89. )
  90. await handler.create_tables()
  91. return handler
  92. @pytest.fixture
  93. async def graphs_handler(db_provider):
  94. project_name = db_provider.project_name
  95. connection_manager = db_provider.connection_manager
  96. dimension = db_provider.dimension
  97. quantization_type = db_provider.quantization_type
  98. # If collections_handler is needed, you can depend on the collections_handler fixture
  99. # or pass None if it's optional.
  100. handler = PostgresGraphsHandler(
  101. project_name=project_name,
  102. connection_manager=connection_manager,
  103. dimension=dimension,
  104. quantization_type=quantization_type,
  105. collections_handler=
  106. None, # if needed, or await collections_handler fixture
  107. )
  108. await handler.create_tables()
  109. return handler
  110. @pytest.fixture
  111. async def limits_handler(db_provider):
  112. project_name = db_provider.project_name
  113. connection_manager = db_provider.connection_manager
  114. config = db_provider.config
  115. handler = PostgresLimitsHandler(
  116. project_name=project_name,
  117. connection_manager=connection_manager,
  118. config=config,
  119. )
  120. await handler.create_tables()
  121. # Optionally truncate
  122. await connection_manager.execute_query(
  123. f"TRUNCATE {handler._get_table_name('request_log')};")
  124. return handler
  125. @pytest.fixture
  126. async def users_handler(db_provider, crypto_provider):
  127. project_name = db_provider.project_name
  128. connection_manager = db_provider.connection_manager
  129. handler = PostgresUserHandler(
  130. project_name=project_name,
  131. connection_manager=connection_manager,
  132. crypto_provider=crypto_provider,
  133. )
  134. await handler.create_tables()
  135. # Optionally clean up users table before each test
  136. await connection_manager.execute_query(
  137. f"TRUNCATE {handler._get_table_name('users')} CASCADE;")
  138. await connection_manager.execute_query(
  139. f"TRUNCATE {handler._get_table_name('users_api_keys')} CASCADE;")
  140. return handler
  141. @pytest.fixture
  142. async def prompt_handler(db_provider):
  143. """Returns an instance of PostgresPromptsHandler, creating the necessary
  144. tables first."""
  145. # from core.providers.database.postgres_prompts import PostgresPromptsHandler
  146. project_name = db_provider.project_name
  147. connection_manager = db_provider.connection_manager
  148. handler = PostgresPromptsHandler(
  149. project_name=project_name,
  150. connection_manager=connection_manager,
  151. # You can specify a local prompt directory if desired
  152. prompt_directory=None,
  153. )
  154. # Create necessary tables and do initial prompt load
  155. await handler.create_tables()
  156. return handler
  157. @pytest.fixture
  158. async def graphs_handler(db_provider):
  159. project_name = db_provider.project_name
  160. connection_manager = db_provider.connection_manager
  161. dimension = db_provider.dimension
  162. quantization_type = db_provider.quantization_type
  163. # Optionally ensure 'collection_ids' column exists on your table(s), e.g.:
  164. create_col_sql = f"""
  165. ALTER TABLE "{project_name}"."graphs_entities"
  166. ADD COLUMN IF NOT EXISTS collection_ids UUID[] DEFAULT '{{}}';
  167. """
  168. await connection_manager.execute_query(create_col_sql)
  169. handler = PostgresGraphsHandler(
  170. project_name=project_name,
  171. connection_manager=connection_manager,
  172. dimension=dimension,
  173. quantization_type=quantization_type,
  174. collections_handler=None,
  175. )
  176. await handler.create_tables()
  177. return handler
  178. # Citation testing fixtures and utilities
  179. import json
  180. import re
  181. from unittest.mock import MagicMock, AsyncMock
  182. from typing import Tuple, Any, AsyncGenerator
  183. from core.base import Message, LLMChatCompletion, LLMChatCompletionChunk, GenerationConfig
  184. from core.utils import CitationTracker, SearchResultsCollector
  185. from core.agent.base import R2RStreamingAgent
  186. class MockLLMProvider:
  187. """Mock LLM provider for testing."""
  188. def __init__(self, response_content=None, citations=None):
  189. self.response_content = response_content or "This is a response"
  190. self.citations = citations or []
  191. async def aget_completion(self, messages, generation_config):
  192. """Mock synchronous completion."""
  193. content = self.response_content
  194. for citation in self.citations:
  195. content += f" [{citation}]"
  196. mock_response = MagicMock(spec=LLMChatCompletion)
  197. mock_response.choices = [MagicMock()]
  198. mock_response.choices[0].message = MagicMock()
  199. mock_response.choices[0].message.content = content
  200. mock_response.choices[0].finish_reason = "stop"
  201. return mock_response
  202. async def aget_completion_stream(self, messages, generation_config):
  203. """Mock streaming completion."""
  204. content = self.response_content
  205. for citation in self.citations:
  206. content += f" [{citation}]"
  207. # Simulate streaming by yielding one character at a time
  208. for i in range(len(content)):
  209. chunk = MagicMock(spec=LLMChatCompletionChunk)
  210. chunk.choices = [MagicMock()]
  211. chunk.choices[0].delta = MagicMock()
  212. chunk.choices[0].delta.content = content[i]
  213. chunk.choices[0].finish_reason = None
  214. yield chunk
  215. # Final chunk with finish_reason="stop"
  216. final_chunk = MagicMock(spec=LLMChatCompletionChunk)
  217. final_chunk.choices = [MagicMock()]
  218. final_chunk.choices[0].delta = MagicMock()
  219. final_chunk.choices[0].delta.content = ""
  220. final_chunk.choices[0].finish_reason = "stop"
  221. yield final_chunk
  222. class MockPromptsHandler:
  223. """Mock prompts handler for testing."""
  224. async def get_cached_prompt(self, prompt_key, inputs=None, *args, **kwargs):
  225. """Return a mock system prompt."""
  226. return "You are a helpful assistant that provides well-sourced information."
  227. class MockDatabaseProvider:
  228. """Mock database provider for testing."""
  229. def __init__(self):
  230. # Add a prompts_handler attribute to prevent AttributeError
  231. self.prompts_handler = MockPromptsHandler()
  232. async def acreate_conversation(self, *args, **kwargs):
  233. return {"id": "conv_12345"}
  234. async def aupdate_conversation(self, *args, **kwargs):
  235. return True
  236. async def acreate_message(self, *args, **kwargs):
  237. return {"id": "msg_12345"}
  238. class MockSearchResultsCollector:
  239. """Mock search results collector for testing."""
  240. def __init__(self, results=None):
  241. self.results = results or {}
  242. def find_by_short_id(self, short_id):
  243. return self.results.get(short_id, {
  244. "document_id": f"doc_{short_id}",
  245. "text": f"This is document text for {short_id}",
  246. "metadata": {"source": f"source_{short_id}"}
  247. })
# Create a concrete implementation of R2RStreamingAgent for testing
class MockR2RStreamingAgent(R2RStreamingAgent):
    """Mock streaming agent for testing that implements the abstract method."""

    # Regex pattern for citations, copied from the actual agent
    BRACKET_PATTERN = re.compile(r"\[([^\]]+)\]")
    SHORT_ID_PATTERN = re.compile(r"[A-Za-z0-9]{7,8}")

    def _register_tools(self):
        """Implement the abstract method with a no-op version."""
        pass

    async def _setup(self, system_instruction=None, *args, **kwargs):
        """Override _setup to simplify initialization and avoid external dependencies."""
        # Use a simple system message instead of fetching from database
        system_content = system_instruction or "You are a helpful assistant that provides well-sourced information."
        # Add system message to conversation
        await self.conversation.add_message(
            Message(role="system", content=system_content)
        )

    def _format_sse_event(self, event_type, data):
        """Format an SSE event manually (event line, JSON data line, blank)."""
        return f"event: {event_type}\ndata: {json.dumps(data)}\n\n"

    async def arun(
        self,
        system_instruction: str = None,
        messages: list[Message] = None,
        *args,
        **kwargs,
    ) -> AsyncGenerator[str, None]:
        """
        Simplified version of arun that focuses on citation handling for testing.

        Emits, in order: one "message" event with the full canned response,
        one "citation" event per newly seen citation span, one
        "agent.final_answer" event with consolidated citations, and a final
        "done" event.
        """
        await self._setup(system_instruction)
        if messages:
            for m in messages:
                await self.conversation.add_message(m)
        # Initialize citation tracker
        citation_tracker = CitationTracker()
        citation_payloads = {}
        # Track streaming citations for final persistence
        self.streaming_citations = []
        # Hard-coded response containing two short-id citations.
        response_content = "This is a test response with citations"
        response_content += " [abc1234] [def5678]"
        # Yield an initial message event with the start of the text
        yield self._format_sse_event("message", {"content": response_content})
        # Manually extract and emit citation events
        # This is a simpler approach than the character-by-character approach
        # NOTE(review): extract_citation_spans is imported at the bottom of
        # this file; it is bound by the time this coroutine runs.
        citation_spans = extract_citation_spans(response_content)
        # Process the citations
        for cid, spans in citation_spans.items():
            for span in spans:
                # Check if the span is new and record it
                if citation_tracker.is_new_span(cid, span):
                    # Look up the source document for this citation
                    source_doc = self.search_results_collector.find_by_short_id(cid)
                    # Create citation payload
                    citation_payload = {
                        "document_id": source_doc.get("document_id", f"doc_{cid}"),
                        "text": source_doc.get("text", f"This is document text for {cid}"),
                        "metadata": source_doc.get("metadata", {"source": f"source_{cid}"}),
                    }
                    # Store the payload by citation ID
                    citation_payloads[cid] = citation_payload
                    # Track for persistence
                    self.streaming_citations.append({
                        "id": cid,
                        "span": {"start": span[0], "end": span[1]},
                        "payload": citation_payload
                    })
                    # Emit citation event in the expected format
                    citation_event = {
                        "id": cid,
                        "object": "citation",
                        "span": {"start": span[0], "end": span[1]},
                        "payload": citation_payload
                    }
                    yield self._format_sse_event("citation", citation_event)
        # Add assistant message with citation metadata to conversation
        await self.conversation.add_message(
            Message(
                role="assistant",
                content=response_content,
                metadata={"citations": self.streaming_citations}
            )
        )
        # Prepare consolidated citations for final answer
        consolidated_citations = []
        # Group citations by ID with all their spans
        for cid, spans in citation_tracker.get_all_spans().items():
            if cid in citation_payloads:
                consolidated_citations.append({
                    "id": cid,
                    "object": "citation",
                    "spans": [{"start": s[0], "end": s[1]} for s in spans],
                    "payload": citation_payloads[cid]
                })
        # Create and emit final answer event
        final_evt_payload = {
            "id": "msg_final",
            "object": "agent.final_answer",
            "generated_answer": response_content,
            "citations": consolidated_citations
        }
        # Manually format the final answer event
        yield self._format_sse_event("agent.final_answer", final_evt_payload)
        # Signal the end of the SSE stream
        yield "event: done\ndata: {}\n\n"
  354. @pytest.fixture
  355. def mock_streaming_agent():
  356. """Create a streaming agent with mocked dependencies."""
  357. # Create mock config
  358. config = MagicMock()
  359. config.stream = True
  360. config.max_iterations = 3
  361. # Create mock providers
  362. llm_provider = MockLLMProvider(
  363. response_content="This is a test response with citations",
  364. citations=["abc1234", "def5678"]
  365. )
  366. db_provider = MockDatabaseProvider()
  367. # Create agent with mocked dependencies using our concrete implementation
  368. agent = MockR2RStreamingAgent(
  369. database_provider=db_provider,
  370. llm_provider=llm_provider,
  371. config=config,
  372. rag_generation_config=GenerationConfig(model="test/model")
  373. )
  374. # Replace the search results collector with our mock
  375. agent.search_results_collector = MockSearchResultsCollector({
  376. "abc1234": {
  377. "document_id": "doc_abc1234",
  378. "text": "This is document text for abc1234",
  379. "metadata": {"source": "source_abc1234"}
  380. },
  381. "def5678": {
  382. "document_id": "doc_def5678",
  383. "text": "This is document text for def5678",
  384. "metadata": {"source": "source_def5678"}
  385. }
  386. })
  387. return agent
  388. async def collect_stream_output(stream):
  389. """Collect all output from a stream into a list."""
  390. output = []
  391. async for event in stream:
  392. output.append(event)
  393. return output
  394. from core.utils import extract_citation_spans, find_new_citation_spans