# test_rag_processing.py
"""
Unit tests for RAG (Retrieval-Augmented Generation) processing functionality.
"""
import pytest
from unittest.mock import AsyncMock, MagicMock, patch, call
from typing import Dict, List, Any, Optional

# Import core classes related to RAG prompt handling
from core.base import Message, SearchSettings


@pytest.fixture
def mock_search_results():
    """Return mock search results for testing prompt construction."""
    return {
        "chunk_search_results": [
            {
                "chunk_id": f"chunk-{i}",
                "document_id": f"doc-{i//2}",
                "text": f"This is search result {i} about Aristotle's philosophy.",
                "metadata": {
                    "source": f"source-{i}",
                    "title": f"Document {i//2}",
                    "page": i + 1,
                },
                "score": 0.95 - (i * 0.05),
            }
            for i in range(5)
        ]
    }


@pytest.fixture
def mock_providers():
    """Create mock providers for testing."""
    providers = AsyncMock()
    providers.llm = AsyncMock()
    providers.llm.aget_completion = AsyncMock(
        return_value={"choices": [{"message": {"content": "LLM generated response"}}]}
    )
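    # Note: this default stream mock returns a plain (synchronous) iterator; the
    # streaming test below replaces it with a true async iterator (MockStream).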
    providers.llm.aget_completion_stream = AsyncMock(
        return_value=iter([{"choices": [{"delta": {"content": "Streamed chunk"}}]}])
    )
    providers.database = AsyncMock()
    providers.database.prompts_handler = AsyncMock()
    providers.database.prompts_handler.get_cached_prompt = AsyncMock(
        return_value="System prompt template with {{context}} placeholder"
    )
    return providers


class TestRAGPromptBuilding:
    """Tests for RAG prompt construction."""

    @pytest.mark.asyncio
    async def test_rag_prompt_construction(self, mock_providers, mock_search_results):
        """Test RAG prompt construction with search results."""
        class RAGPromptBuilder:
            def __init__(self, providers):
                self.providers = providers

            async def build_prompt(self, query, search_results, system_prompt_template_id=None, include_metadata=True):
                # Simple implementation that handles search results
                chunks = search_results.get("chunk_search_results", [])
                context = ""
                for i, chunk in enumerate(chunks):
                    # Format the chunk text
                    chunk_text = f"[{i+1}] {chunk.get('text', '')}"
                    # Add metadata if requested
                    if include_metadata:
                        metadata_items = []
                        for key, value in chunk.get("metadata", {}).items():
                            if key not in ["embedding"]:  # Skip non-user-friendly fields
                                metadata_items.append(f"{key}: {value}")
                        if metadata_items:
                            metadata_str = ", ".join(metadata_items)
                            chunk_text += f" ({metadata_str})"
                    context += chunk_text + "\n\n"
                return [
                    {"role": "system", "content": f"System prompt with context:\n\n{context}"},
                    {"role": "user", "content": query},
                ]

        # Create a RAG prompt builder
        builder = RAGPromptBuilder(providers=mock_providers)

        # Call the build method
        query = "What did Aristotle say about ethics?"
        messages = await builder.build_prompt(
            query=query,
            search_results=mock_search_results,
            system_prompt_template_id="default_rag_prompt",
            include_metadata=True,
        )

        # Check that the messages list was constructed properly
        assert len(messages) > 0

        # Find the system message
        system_message = next((m for m in messages if m["role"] == "system"), None)
        assert system_message is not None, "System message should be present"

        # Check that context was injected into system message
        assert "search result" in system_message["content"], "System message should contain search results"

        # Check that metadata was included
        assert "source" in system_message["content"] or "title" in system_message["content"], \
            "System message should contain metadata when include_metadata=True"

        # Find the user message
        user_message = next((m for m in messages if m["role"] == "user"), None)
        assert user_message is not None, "User message should be present"
        assert user_message["content"] == query, "User message should contain the query"

    @pytest.mark.asyncio
    async def test_rag_prompt_construction_without_metadata(self, mock_providers, mock_search_results):
        """Test RAG prompt construction without metadata."""
        class RAGPromptBuilder:
            def __init__(self, providers):
                self.providers = providers

            async def build_prompt(self, query, search_results, system_prompt_template_id=None, include_metadata=True):
                # Simple implementation that handles search results
                chunks = search_results.get("chunk_search_results", [])
                context = ""
                for i, chunk in enumerate(chunks):
                    # Format the chunk text
                    chunk_text = f"[{i+1}] {chunk.get('text', '')}"
                    # Add metadata if requested
                    if include_metadata:
                        metadata_items = []
                        for key, value in chunk.get("metadata", {}).items():
                            if key not in ["embedding"]:  # Skip non-user-friendly fields
                                metadata_items.append(f"{key}: {value}")
                        if metadata_items:
                            metadata_str = ", ".join(metadata_items)
                            chunk_text += f" ({metadata_str})"
                    context += chunk_text + "\n\n"
                return [
                    {"role": "system", "content": f"System prompt with context:\n\n{context}"},
                    {"role": "user", "content": query},
                ]

        # Create a RAG prompt builder
        builder = RAGPromptBuilder(providers=mock_providers)

        # Call the build method without metadata
        query = "What did Aristotle say about ethics?"
        messages = await builder.build_prompt(
            query=query,
            search_results=mock_search_results,
            system_prompt_template_id="default_rag_prompt",
            include_metadata=False,
        )

        # Find the system message
        system_message = next((m for m in messages if m["role"] == "system"), None)

        # Ensure metadata is not included
        for term in ["source", "title", "page"]:
            assert term not in system_message["content"].lower(), \
                f"System message should not contain metadata term '{term}' when include_metadata=False"

    @pytest.mark.asyncio
    async def test_rag_prompt_with_task_prompt(self, mock_providers, mock_search_results):
        """Test RAG prompt construction with a task prompt."""
        class RAGPromptBuilder:
            def __init__(self, providers):
                self.providers = providers

            async def build_prompt(self, query, search_results, system_prompt_template_id=None, task_prompt=None):
                # Simple implementation that handles search results
                chunks = search_results.get("chunk_search_results", [])
                context = ""
                for i, chunk in enumerate(chunks):
                    # Format the chunk text
                    chunk_text = f"[{i+1}] {chunk.get('text', '')}"
                    context += chunk_text + "\n\n"
                if task_prompt:
                    context += f"\n\nTask: {task_prompt}"
                return [
                    {"role": "system", "content": f"System prompt with context:\n\n{context}"},
                    {"role": "user", "content": query},
                ]

        # Create a RAG prompt builder
        builder = RAGPromptBuilder(providers=mock_providers)

        # Call the build method with a task prompt
        query = "What did Aristotle say about ethics?"
        task_prompt = "Summarize the information and provide key points only"
        messages = await builder.build_prompt(
            query=query,
            search_results=mock_search_results,
            system_prompt_template_id="default_rag_prompt",
            task_prompt=task_prompt,
        )

        # Find the messages
        system_message = next((m for m in messages if m["role"] == "system"), None)
        user_message = next((m for m in messages if m["role"] == "user"), None)

        # Check that task prompt was incorporated
        assert task_prompt in system_message["content"] or task_prompt in user_message["content"], \
            "Task prompt should be incorporated into the messages"

    @pytest.mark.asyncio
    async def test_rag_prompt_with_conversation_history(self, mock_providers, mock_search_results):
        """Test RAG prompt construction with conversation history."""
        class RAGPromptBuilder:
            def __init__(self, providers):
                self.providers = providers

            async def build_prompt(self, query, search_results, system_prompt_template_id=None, conversation_history=None):
                # Simple implementation that handles search results
                chunks = search_results.get("chunk_search_results", [])
                context = ""
                for i, chunk in enumerate(chunks):
                    # Format the chunk text
                    chunk_text = f"[{i+1}] {chunk.get('text', '')}"
                    context += chunk_text + "\n\n"
                messages = [
                    {"role": "system", "content": f"System prompt with context:\n\n{context}"}
                ]
                # Add conversation history if provided
                if conversation_history:
                    messages.extend(conversation_history)
                else:
                    # Only add the query as a separate message if no conversation history
                    messages.append({"role": "user", "content": query})
                return messages

        # Create a RAG prompt builder
        builder = RAGPromptBuilder(providers=mock_providers)

        # Setup conversation history
        conversation_history = [
            {"role": "user", "content": "Tell me about Aristotle"},
            {"role": "assistant", "content": "Aristotle was a Greek philosopher."},
            {"role": "user", "content": "What about his ethics?"}
        ]

        # The last message in conversation history is the query
        query = conversation_history[-1]["content"]
        messages = await builder.build_prompt(
            query=query,
            search_results=mock_search_results,
            system_prompt_template_id="default_rag_prompt",
            conversation_history=conversation_history,
        )

        # Check that all conversation messages are included
        history_messages = [m for m in messages if m["role"] in ["user", "assistant"]]
        assert len(history_messages) == len(conversation_history), \
            "All conversation history messages should be included"

        # Check that the conversation history is preserved in the correct order
        for i, msg in enumerate(history_messages):
            assert msg["role"] == conversation_history[i]["role"]
            assert msg["content"] == conversation_history[i]["content"]

    @pytest.mark.asyncio
    async def test_rag_prompt_with_citations(self, mock_providers, mock_search_results):
        """Test RAG prompt construction with citation information."""
        class RAGPromptBuilder:
            def __init__(self, providers):
                self.providers = providers

            async def build_prompt(self, query, search_results, system_prompt_template_id=None, include_citations=True):
                # Simple implementation that handles search results
                chunks = search_results.get("chunk_search_results", [])
                context = ""
                for i, chunk in enumerate(chunks):
                    # Format the chunk text
                    chunk_text = f"[{i+1}] {chunk.get('text', '')}"
                    # Add citation marker if requested
                    citation_id = chunk.get("metadata", {}).get("citation_id")
                    if include_citations and citation_id:
                        chunk_text += f" [{citation_id}]"
                    context += chunk_text + "\n\n"
                # Include instructions about citations
                citation_instructions = ""
                if include_citations:
                    citation_instructions = "\n\nWhen referring to the context, include citation markers like [cit-0] to attribute information to its source."
                return [
                    {"role": "system", "content": f"System prompt with context:\n\n{context}{citation_instructions}"},
                    {"role": "user", "content": query},
                ]

        # Add citation metadata to the search results
        for i, result in enumerate(mock_search_results["chunk_search_results"]):
            result["metadata"]["citation_id"] = f"cit-{i}"

        # Create a RAG prompt builder
        builder = RAGPromptBuilder(providers=mock_providers)

        # Call the build method with citations enabled
        query = "What did Aristotle say about ethics?"
        messages = await builder.build_prompt(
            query=query,
            search_results=mock_search_results,
            system_prompt_template_id="default_rag_prompt",
            include_citations=True,
        )

        # Find the system message
        system_message = next((m for m in messages if m["role"] == "system"), None)

        # Check that citation markers are included in the context
        assert any(f"[cit-{i}]" in system_message["content"] for i in range(5)), \
            "Citation markers should be included in the context"

        # Check for citation instructions in the prompt
        assert "citation" in system_message["content"].lower(), \
            "System message should include instructions about using citations"

    @pytest.mark.asyncio
    async def test_rag_custom_system_prompt(self, mock_providers, mock_search_results):
        """Test RAG prompt construction with a custom system prompt."""
        class RAGPromptBuilder:
            def __init__(self, providers):
                self.providers = providers

            async def build_prompt(self, query, search_results, system_prompt_template_id=None):
                # Simple implementation that handles search results
                chunks = search_results.get("chunk_search_results", [])
                context = ""
                for i, chunk in enumerate(chunks):
                    # Format the chunk text
                    chunk_text = f"[{i+1}] {chunk.get('text', '')}"
                    context += chunk_text + "\n\n"
                # Get the custom system prompt template
                custom_prompt = "Custom system prompt with {{context}} and some instructions"
                if system_prompt_template_id:
                    # In a real implementation, this would fetch the template from a database
                    custom_prompt = f"Custom system prompt for {system_prompt_template_id} with {{{{context}}}}"
                # Replace the context placeholder with actual context
                system_content = custom_prompt.replace("{{context}}", context)
                return [
                    {"role": "system", "content": system_content},
                    {"role": "user", "content": query},
                ]

        # Create a custom system prompt template
        custom_prompt = "Custom system prompt with {{context}} and some instructions"

        # Create a RAG prompt builder
        builder = RAGPromptBuilder(providers=mock_providers)

        # Call the build method with a custom system prompt template ID
        query = "What did Aristotle say about ethics?"
        messages = await builder.build_prompt(
            query=query,
            search_results=mock_search_results,
            system_prompt_template_id="custom_template_id",
        )

        # Find the system message
        system_message = next((m for m in messages if m["role"] == "system"), None)

        # Check that the custom prompt was used
        assert "Custom system prompt" in system_message["content"], \
            "System message should use the custom prompt template"

        # Check that context was still injected
        assert "search result" in system_message["content"], \
            "Context should still be injected into custom prompt"


class TestRAGProcessing:
    """Tests for RAG processing and generation."""

    @pytest.mark.asyncio
    async def test_rag_generation(self, mock_providers, mock_search_results):
        """Test generating a response using RAG."""
        class RAGProcessor:
            def __init__(self, providers):
                self.providers = providers
                self.prompt_builder = MagicMock()
                self.prompt_builder.build_prompt = AsyncMock(
                    return_value=[
                        {"role": "system", "content": "System prompt with context"},
                        {"role": "user", "content": "What did Aristotle say about ethics?"}
                    ]
                )

            async def generate(self, query, search_results, **kwargs):
                # Build the prompt
                messages = await self.prompt_builder.build_prompt(
                    query=query,
                    search_results=search_results,
                    **kwargs
                )
                # Generate a response
                response = await self.providers.llm.aget_completion(messages=messages)
                return response["choices"][0]["message"]["content"]

        # Create the processor
        processor = RAGProcessor(mock_providers)

        # Generate a response
        query = "What did Aristotle say about ethics?"
        response = await processor.generate(
            query=query,
            search_results=mock_search_results
        )

        # Verify the LLM was called
        mock_providers.llm.aget_completion.assert_called_once()

        # Check the response
        assert response == "LLM generated response"

    @pytest.mark.asyncio
    async def test_rag_streaming(self, mock_providers, mock_search_results):
        """Test streaming a response using RAG."""
        class RAGProcessor:
            def __init__(self, providers):
                self.providers = providers
                self.prompt_builder = MagicMock()
                self.prompt_builder.build_prompt = AsyncMock(
                    return_value=[
                        {"role": "system", "content": "System prompt with context"},
                        {"role": "user", "content": "What did Aristotle say about ethics?"}
                    ]
                )

            async def generate_stream(self, query, search_results, **kwargs):
                # Build the prompt
                messages = await self.prompt_builder.build_prompt(
                    query=query,
                    search_results=search_results,
                    **kwargs
                )
                # Generate a streaming response
                stream = await self.providers.llm.aget_completion_stream(messages=messages)
                return stream

        # Create a mock stream
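        # MockStream implements the async-iterator protocol (__aiter__/__anext__),
        # so the returned stream can be consumed with `async for` below.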
        class MockStream:
            def __init__(self, chunks):
                self.chunks = chunks
                self.index = 0

            def __aiter__(self):
                return self

            async def __anext__(self):
                if self.index >= len(self.chunks):
                    raise StopAsyncIteration
                chunk = self.chunks[self.index]
                self.index += 1
                return chunk

        # Configure the LLM mock to return an async iterable stream
        mock_stream = MockStream([
            {"choices": [{"delta": {"content": "This"}}]},
            {"choices": [{"delta": {"content": " is"}}]},
            {"choices": [{"delta": {"content": " a"}}]},
            {"choices": [{"delta": {"content": " test"}}]},
            {"choices": [{"delta": {"content": " response."}}]}
        ])
        mock_providers.llm.aget_completion_stream = AsyncMock(return_value=mock_stream)

        # Create the processor
        processor = RAGProcessor(mock_providers)

        # Generate a streaming response
        query = "What did Aristotle say about ethics?"
        stream = await processor.generate_stream(
            query=query,
            search_results=mock_search_results
        )

        # Verify the LLM streaming method was called
        mock_providers.llm.aget_completion_stream.assert_called_once()

        # Process the stream
        chunks = []
        async for chunk in stream:
            chunks.append(chunk)

        # Verify chunks were received
        assert len(chunks) == 5, "Should receive all 5 chunks"
        assert chunks[0]["choices"][0]["delta"]["content"] == "This", "First chunk content should match"
        assert chunks[-1]["choices"][0]["delta"]["content"] == " response.", "Last chunk content should match"

    @pytest.mark.asyncio
    async def test_rag_with_different_provider_models(self, mock_providers, mock_search_results):
        """Test RAG with different provider models."""
        class RAGProcessor:
            def __init__(self, providers):
                self.providers = providers
                self.prompt_builder = MagicMock()
                self.prompt_builder.build_prompt = AsyncMock(
                    return_value=[
                        {"role": "system", "content": "System prompt with context"},
                        {"role": "user", "content": "What did Aristotle say about ethics?"}
                    ]
                )

            async def generate(self, query, search_results, model=None, **kwargs):
                # Build the prompt
                messages = await self.prompt_builder.build_prompt(
                    query=query,
                    search_results=search_results,
                    **kwargs
                )
                # Generate a response with the specified model
                response = await self.providers.llm.aget_completion(
                    messages=messages,
                    model=model
                )
                return response["choices"][0]["message"]["content"]

        # Create the processor
        processor = RAGProcessor(mock_providers)

        # Generate responses with different models
        query = "What did Aristotle say about ethics?"
        models = ["gpt-4", "claude-3-opus", "gemini-pro"]
        for model in models:
            await processor.generate(
                query=query,
                search_results=mock_search_results,
                model=model
            )
            # Verify the LLM was called with the correct model
            call_kwargs = mock_providers.llm.aget_completion.call_args[1]
            assert call_kwargs["model"] == model
            # Reset the mock for the next iteration
            mock_providers.llm.aget_completion.reset_mock()


class TestRAGContextFormatting:
    """Tests for formatting context in RAG prompts."""

    def test_default_context_formatting(self, mock_search_results):
        """Test the default formatting of context from search results."""
        # Function to format context
        def format_context(search_results, include_metadata=True):
            context = ""
            for i, result in enumerate(search_results["chunk_search_results"]):
                # Format the chunk text
                chunk_text = f"[{i+1}] {result['text']}"
                # Add metadata if requested
                if include_metadata:
                    metadata_items = []
                    for key, value in result.get("metadata", {}).items():
                        if key not in ["embedding"]:  # Skip non-user-friendly fields
                            metadata_items.append(f"{key}: {value}")
                    if metadata_items:
                        metadata_str = ", ".join(metadata_items)
                        chunk_text += f" ({metadata_str})"
                context += chunk_text + "\n\n"
            return context.strip()

        # Format context with metadata
        context_with_metadata = format_context(mock_search_results)

        # Check formatting
        assert "[1]" in context_with_metadata
        assert "source" in context_with_metadata
        assert "title" in context_with_metadata

        # Format context without metadata
        context_without_metadata = format_context(mock_search_results, include_metadata=False)

        # Check formatting
        assert "[1]" in context_without_metadata
        assert "source" not in context_without_metadata
        assert "title" not in context_without_metadata

    def test_numbered_list_context_formatting(self, mock_search_results):
        """Test numbered list formatting of context."""
        # Function to format context as a numbered list
        def format_context_numbered_list(search_results):
            context_items = []
            for i, result in enumerate(search_results["chunk_search_results"]):
                context_items.append(f"{i+1}. {result['text']}")
            return "\n".join(context_items)

        # Format context
        context = format_context_numbered_list(mock_search_results)

        # Check formatting
        assert "1. " in context
        assert "2. " in context
        assert "3. " in context
        assert "4. " in context
        assert "5. " in context

    def test_source_attribution_context_formatting(self, mock_search_results):
        """Test context formatting with source attribution."""
        # Function to format context with source attribution
        def format_context_with_sources(search_results):
            context_items = []
            for result in search_results["chunk_search_results"]:
                source = result.get("metadata", {}).get("source", "Unknown source")
                title = result.get("metadata", {}).get("title", "Unknown title")
                context_items.append(f"From {source} ({title}):\n{result['text']}")
            return "\n\n".join(context_items)

        # Format context
        context = format_context_with_sources(mock_search_results)

        # Check formatting
        assert "From source-0" in context
        assert "Document 0" in context
        assert "From source-1" in context

    def test_citation_marker_context_formatting(self, mock_search_results):
        """Test context formatting with citation markers."""
        # Add citation IDs to search results
        for i, result in enumerate(mock_search_results["chunk_search_results"]):
            result["metadata"]["citation_id"] = f"cit{i}"

        # Function to format context with citation markers
        def format_context_with_citations(search_results):
            context_items = []
            for i, result in enumerate(search_results["chunk_search_results"]):
                citation_id = result.get("metadata", {}).get("citation_id")
                text = result["text"]
                if citation_id:
                    context_items.append(f"[{i+1}] {text} [{citation_id}]")
                else:
                    context_items.append(f"[{i+1}] {text}")
            return "\n\n".join(context_items)

        # Format context
        context = format_context_with_citations(mock_search_results)

        # Check formatting
        assert "[cit0]" in context
        assert "[cit1]" in context
        assert "[cit2]" in context


class TestRAGErrorHandling:
    """Tests for handling errors in RAG processing."""

    @pytest.mark.asyncio
    async def test_rag_with_empty_search_results(self, mock_providers):
        """Test RAG behavior with empty search results."""
        class RAGPromptBuilder:
            def __init__(self, providers):
                self.providers = providers

            async def build_prompt(self, query, search_results, system_prompt_template_id=None):
                # Simple implementation that handles empty results gracefully
                if not search_results.get("chunk_search_results"):
                    return [
                        {"role": "system", "content": "No relevant information was found for your query."},
                        {"role": "user", "content": query}
                    ]
                return []

        # Create a RAG prompt builder
        builder = RAGPromptBuilder(providers=mock_providers)

        # Setup empty search results
        empty_search_results = {"chunk_search_results": []}

        # Call the build method with empty results
        query = "What did Aristotle say about ethics?"
        messages = await builder.build_prompt(
            query=query,
            search_results=empty_search_results,
            system_prompt_template_id="default_rag_prompt"
        )

        # Find the system message
        system_message = next((m for m in messages if m["role"] == "system"), None)

        # Check that the system message handles empty results gracefully
        assert system_message is not None, "System message should be present even with empty results"
        assert "no relevant information" in system_message["content"].lower(), \
            "System message should indicate that no relevant information was found"

    @pytest.mark.asyncio
    async def test_rag_with_malformed_search_results(self, mock_providers):
        """Test RAG behavior with malformed search results."""
        class RAGPromptBuilder:
            def __init__(self, providers):
                self.providers = providers

            async def build_prompt(self, query, search_results, system_prompt_template_id=None):
                # Handle malformed results by including whatever is available
                chunks = search_results.get("chunk_search_results", [])
                context = ""
                for chunk in chunks:
                    # Handle missing fields gracefully
                    text = chunk.get("text", "No text content")
                    context += text + "\n\n"
                return [
                    {"role": "system", "content": f"Context:\n{context}\n\nBased on the above context, answer the following question."},
                    {"role": "user", "content": query}
                ]

        # Create a RAG prompt builder
        builder = RAGPromptBuilder(providers=mock_providers)

        # Setup malformed search results (missing required fields)
        malformed_search_results = {
            "chunk_search_results": [
                {
                    # Missing chunk_id, document_id
                    "text": "Malformed result without required fields"
                    # Missing metadata
                }
            ]
        }

        # Call the build method with malformed results
        query = "What did Aristotle say about ethics?"
        messages = await builder.build_prompt(
            query=query,
            search_results=malformed_search_results,
            system_prompt_template_id="default_rag_prompt"
        )

        # Find the system message
        system_message = next((m for m in messages if m["role"] == "system"), None)

        # Check that the system message handles malformed results gracefully
        assert system_message is not None, "System message should be present even with malformed results"
        assert "Malformed result" in system_message["content"], \
            "The text content should still be included"

    @pytest.mark.asyncio
    async def test_rag_with_llm_error_recovery(self, mock_providers, mock_search_results):
        """Test RAG recovery from LLM errors."""
        class RAGProcessorWithErrorRecovery:
            def __init__(self, providers):
                self.providers = providers
                self.prompt_builder = MagicMock()
                self.prompt_builder.build_prompt = AsyncMock(
                    return_value=[
                        {"role": "system", "content": "System prompt with context"},
                        {"role": "user", "content": "What did Aristotle say about ethics?"}
                    ]
                )
                # Configure the LLM mock to fail on first call, succeed on second
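                # (with a side_effect sequence, unittest.mock raises exception
                # instances and returns any other values, in call order)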
                self.providers.llm.aget_completion = AsyncMock(side_effect=[
                    Exception("LLM API error"),
                    {"choices": [{"message": {"content": "Fallback response after error"}}]}
                ])

            async def generate_with_error_recovery(self, query, search_results, **kwargs):
                # Build the prompt
                messages = await self.prompt_builder.build_prompt(
                    query=query,
                    search_results=search_results,
                    **kwargs
                )
                # Try with primary model
                try:
                    response = await self.providers.llm.aget_completion(
                        messages=messages,
                        model="primary_model"
                    )
                    return response["choices"][0]["message"]["content"]
                except Exception:
                    # On error, try with fallback model
                    response = await self.providers.llm.aget_completion(
                        messages=messages,
                        model="fallback_model"
                    )
                    return response["choices"][0]["message"]["content"]

        # Create the processor
        processor = RAGProcessorWithErrorRecovery(mock_providers)

        # Generate a response with error recovery
        query = "What did Aristotle say about ethics?"
        response = await processor.generate_with_error_recovery(
            query=query,
            search_results=mock_search_results
        )

        # Verify both LLM calls were made
        assert mock_providers.llm.aget_completion.call_count == 2

        # Check the second call used the fallback model
        second_call_kwargs = mock_providers.llm.aget_completion.call_args_list[1][1]
        assert second_call_kwargs["model"] == "fallback_model"

        # Check the response is from the fallback
        assert response == "Fallback response after error"


class TestRAGContextTruncation:
    """Tests for context truncation strategies in RAG."""

    def test_token_count_truncation(self, mock_search_results):
        """Test truncating context based on token count."""
        # Function to truncate context to max tokens
        def truncate_context_by_tokens(search_results, max_tokens=1000):
            # Simple token counting function (in real code, use a tokenizer)
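            # (the 4-chars-per-token ratio below is only a rough heuristic for
            # English text; a real implementation would use the model's own
            # tokenizer, e.g. tiktoken for OpenAI models)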
            def estimate_tokens(text):
                # Rough approximation: 4 chars ~ 1 token
                return len(text) // 4

            context_items = []
            current_tokens = 0
            # Add chunks until we hit the token limit
            for result in search_results["chunk_search_results"]:
                chunk_text = result["text"]
                chunk_tokens = estimate_tokens(chunk_text)
                if current_tokens + chunk_tokens > max_tokens:
                    # If this chunk would exceed the limit, stop
                    break
                # Add this chunk and update token count
                context_items.append(chunk_text)
                current_tokens += chunk_tokens
            return "\n\n".join(context_items)

        # Truncate to a small token limit (should fit ~2-3 chunks)
        small_context = truncate_context_by_tokens(mock_search_results, max_tokens=50)

        # Check truncation
        chunk_count = small_context.count("search result")
        assert 1 <= chunk_count <= 3, "Should only include 1-3 chunks with small token limit"

        # Truncate with larger limit (should fit all chunks)
        large_context = truncate_context_by_tokens(mock_search_results, max_tokens=1000)
        large_chunk_count = large_context.count("search result")
        assert large_chunk_count == 5, "Should include all 5 chunks with large token limit"

    def test_score_threshold_truncation(self, mock_search_results):
        """Test truncating context based on relevance score threshold."""
        # Function to truncate context based on minimum score
        def truncate_context_by_score(search_results, min_score=0.7):
            context_items = []
            # Add chunks that meet the minimum score
            for result in search_results["chunk_search_results"]:
                if result.get("score", 0) >= min_score:
                    context_items.append(result["text"])
            return "\n\n".join(context_items)

        # Truncate with high score threshold (should only include top results)
        high_threshold_context = truncate_context_by_score(mock_search_results, min_score=0.85)

        # Check truncation
        high_chunk_count = high_threshold_context.count("search result")
        assert high_chunk_count <= 3, "Should only include top chunks with high score threshold"

        # Truncate with low score threshold (should include most or all chunks)
        low_threshold_context = truncate_context_by_score(mock_search_results, min_score=0.7)
        low_chunk_count = low_threshold_context.count("search result")
        assert low_chunk_count >= 4, "Should include most chunks with low score threshold"

    def test_mixed_truncation_strategy(self, mock_search_results):
        """Test mixed truncation strategy combining token count and score."""
        # Function implementing mixed truncation strategy
        def mixed_truncation_strategy(search_results, max_tokens=1000, min_score=0.7):
            # First filter by score
            filtered_results = [r for r in search_results["chunk_search_results"]
                                if r.get("score", 0) >= min_score]

            # Then truncate by tokens
            def estimate_tokens(text):
                return len(text) // 4

            context_items = []
            current_tokens = 0
            for result in filtered_results:
                chunk_text = result["text"]
                chunk_tokens = estimate_tokens(chunk_text)
                if current_tokens + chunk_tokens > max_tokens:
                    break
                context_items.append(chunk_text)
                current_tokens += chunk_tokens
            return "\n\n".join(context_items)

        # Test the mixed strategy
        context = mixed_truncation_strategy(
            mock_search_results,
            max_tokens=50,
            min_score=0.8
        )

        # Check result
        chunk_count = context.count("search result")
        assert 1 <= chunk_count <= 3, "Mixed strategy should limit results appropriately"


class TestAdvancedCitationHandling:
    """Tests for advanced citation handling in RAG."""

    @pytest.fixture
    def mock_citation_results(self):
        """Return mock search results with citation information."""
        results = {
            "chunk_search_results": [
                {
                    "chunk_id": f"chunk-{i}",
                    "document_id": f"doc-{i//2}",
                    "text": f"This is search result {i} about Aristotle's philosophy.",
                    "metadata": {
                        "source": f"source-{i}",
                        "title": f"Document {i//2}",
                        "page": i + 1,
                        "citation_id": f"cite{i}",
                        "authors": ["Author A", "Author B"] if i % 2 == 0 else ["Author C"]
                    },
                    "score": 0.95 - (i * 0.05),
                }
                for i in range(5)
            ]
        }
        return results

    def test_structured_citation_formatting(self, mock_citation_results):
        """Test formatting structured citations in an academic format."""
        # Function to format structured citations
        def format_structured_citations(search_results):
            citations = {}
            # Extract citation information
            for result in search_results["chunk_search_results"]:
                citation_id = result.get("metadata", {}).get("citation_id")
                if not citation_id:
                    continue
                # Skip if we've already processed this citation
                if citation_id in citations:
                    continue
                # Extract metadata
                metadata = result.get("metadata", {})
                authors = metadata.get("authors", [])
                title = metadata.get("title", "Untitled")
                source = metadata.get("source", "Unknown source")
                page = metadata.get("page", None)
                # Format the citation in academic style
                author_text = ", ".join(authors) if authors else "Unknown author"
                citation_text = f"{author_text}. \"{title}\". {source}"
                if page:
                    citation_text += f", p. {page}"
                # Store the formatted citation
                citations[citation_id] = {
                    "text": citation_text,
                    "document_id": result.get("document_id"),
                    "chunk_id": result.get("chunk_id")
                }
            return citations

        # Format citations
        citations = format_structured_citations(mock_citation_results)

        # Check formatting
        assert len(citations) == 5, "Should have 5 unique citations"
        assert "Author A, Author B" in citations["cite0"]["text"], "Should include authors"
        assert "Document 0" in citations["cite0"]["text"], "Should include title"
        assert "source-0" in citations["cite0"]["text"], "Should include source"
        assert "p. 1" in citations["cite0"]["text"], "Should include page number"

    def test_inline_citation_replacement(self, mock_citation_results):
        """Test replacing citation placeholders with actual citations."""
        # First format the context with citation placeholders
        def format_context_with_citations(search_results):
            context_items = []
            for i, result in enumerate(search_results["chunk_search_results"]):
                citation_id = result.get("metadata", {}).get("citation_id")
                text = result["text"]
                if citation_id:
                    context_items.append(f"{text} [{citation_id}]")
                else:
                    context_items.append(text)
            return "\n\n".join(context_items)

        # Function to replace citation placeholders in LLM response
        def replace_citation_placeholders(response_text, citation_metadata):
            # Simple regex-based replacement
            import re

            def citation_replacement(match):
                citation_id = match.group(1)
                if citation_id in citation_metadata:
                    citation = citation_metadata[citation_id]
                    authors = citation.get("authors", ["Unknown author"])
                    year = citation.get("year", "n.d.")
                    return f"({authors[0]} et al., {year})"
                return match.group(0)  # Keep original if not found

            # Replace [citeX] format
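            # (group 1 of the pattern below captures the bare citation id,
            # e.g. "cite0", without the surrounding brackets)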
            pattern = r'\[(cite\d+)\]'
            return re.sub(pattern, citation_replacement, response_text)

        # Create mock citation metadata
        citation_metadata = {
            f"cite{i}": {
                "authors": [f"Author {chr(65+i)}"] + (["et al."] if i % 2 == 0 else []),
                "year": 2020 + i,
                "title": f"Document {i//2}"
            }
            for i in range(5)
        }

        # Response with citation placeholders
        response_with_placeholders = (
            "Aristotle's ethics [cite0] focuses on virtue ethics. "
            "This contrasts with utilitarianism [cite2] which focuses on outcomes. "
            "Later philosophers [cite4] expanded on these ideas."
        )

        # Replace placeholders
        final_response = replace_citation_placeholders(response_with_placeholders, citation_metadata)

        # Check formatting
        assert "(Author A et al., 2020)" in final_response, "Author A citation should be in the response"
        assert "(Author C" in final_response, "Author C citation should be in the response"
        assert "(Author E" in final_response, "Author E citation should be in the response"
        assert "[cite0]" not in final_response, "Citation placeholder [cite0] should be replaced"
        assert "[cite2]" not in final_response, "Citation placeholder [cite2] should be replaced"
        assert "[cite4]" not in final_response, "Citation placeholder [cite4] should be replaced"

    def test_hybrid_citation_strategy(self, mock_citation_results):
        """Test a hybrid citation strategy with footnotes and a bibliography."""
        # Function to process text with the hybrid citation strategy
        def process_with_hybrid_citations(response_text, citation_metadata):
            import re

            # Step 1: Replace inline citations with footnote numbers
            footnotes = []
            footnote_index = 1

            def footnote_replacement(match):
                nonlocal footnote_index
                citation_id = match.group(1)
                if citation_id in citation_metadata:
                    # Add footnote
                    citation = citation_metadata[citation_id]
                    source = citation.get("source", "Unknown source")
                    title = citation.get("title", "Untitled")
                    authors = citation.get("authors", ["Unknown author"])
                    author_text = ", ".join(authors)
                    footnote = f"{footnote_index}. {author_text}. \"{title}\". {source}."
                    footnotes.append(footnote)
                    # Return footnote reference in text
                    result = f"[{footnote_index}]"
                    footnote_index += 1
                    return result
                return match.group(0)  # Keep original if not found

            # Replace [citeX] format with footnote numbers
            pattern = r'\[(cite\d+)\]'
            processed_text = re.sub(pattern, footnote_replacement, response_text)

            # Step 2: Add footnotes at the end
            if footnotes:
                processed_text += "\n\nFootnotes:\n" + "\n".join(footnotes)

            # Step 3: Add a bibliography entry for each citation that appears in the response
            bibliography = []
            for citation_id, citation in citation_metadata.items():
                if f"[{citation_id}]" in response_text:
                    source = citation.get("source", "Unknown source")
                    title = citation.get("title", "Untitled")
                    authors = citation.get("authors", ["Unknown author"])
                    year = citation.get("year", "n.d.")
                    bib_entry = f"{', '.join(authors)}. ({year}). \"{title}\". {source}."
                    bibliography.append(bib_entry)
            if bibliography:
                processed_text += "\n\nBibliography:\n" + "\n".join(bibliography)
            return processed_text

        # Create mock citation metadata
        citation_metadata = {
            f"cite{i}": {
                "authors": [f"Author {chr(65+i)}"] + (["et al."] if i % 2 == 0 else []),
                "year": 2020 + i,
                "title": f"Document {i//2}",
                "source": f"Journal of Philosophy, Volume {i+1}"
            }
            for i in range(5)
        }

        # Response with citation placeholders
        response_with_placeholders = (
            "Aristotle's ethics [cite0] focuses on virtue ethics. "
            "This contrasts with utilitarianism [cite2] which focuses on outcomes. "
            "Later philosophers [cite4] expanded on these ideas."
        )

        # Apply hybrid citation processing
        final_response = process_with_hybrid_citations(response_with_placeholders, citation_metadata)

        # Check formatting
        assert "[1]" in final_response
        assert "[2]" in final_response
        assert "[3]" in final_response
        assert "Footnotes:" in final_response
        assert "Bibliography:" in final_response
        assert "Journal of Philosophy" in final_response
        assert "[cite0]" not in final_response
        assert "[cite2]" not in final_response
        assert "[cite4]" not in final_response


class TestRAGRetrievalStrategies:
    """Tests for different retrieval strategies in RAG."""

    @pytest.mark.asyncio
    async def test_hybrid_search_strategy(self, mock_providers):
        """Test hybrid search combining keyword and semantic search."""
        # Mock search results
        keyword_results = {
            "chunk_search_results": [
                {
                    "chunk_id": f"keyword-chunk-{i}",
                    "document_id": f"doc-{i}",
                    "text": f"Keyword match {i} about Aristotle's ethics.",
                    "metadata": {"source": f"source-{i}"},
                    "score": 0.95 - (i * 0.05),
                }
                for i in range(3)
            ]
        }
        semantic_results = {
            "chunk_search_results": [
                {
                    "chunk_id": f"semantic-chunk-{i}",
                    "document_id": f"doc-{i+5}",
                    "text": f"Semantic match {i} about virtue ethics philosophy.",
                    "metadata": {"source": f"source-{i+5}"},
                    "score": 0.9 - (i * 0.05),
                }
                for i in range(3)
            ]
        }

        # Mock hybrid search function
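        # (this stub simply takes the top two results from each modality; a real
        # implementation would merge by score, e.g. reciprocal rank fusion, and
        # deduplicate by chunk_id)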
        async def perform_hybrid_search(query, **kwargs):
            # Perform both search types
            # In a real implementation, these would be actual search calls
            keyword_results_copy = keyword_results.copy()
            semantic_results_copy = semantic_results.copy()
            # Combine and deduplicate results
            combined_results = {
                "chunk_search_results":
                    keyword_results_copy["chunk_search_results"][:2] +
                    semantic_results_copy["chunk_search_results"][:2]
            }
            return combined_results

        # Mock RAG processor using hybrid search
        class HybridSearchRAGProcessor:
            def __init__(self, providers):
                self.providers = providers
                # Use a prompt builder that includes the actual search content
                self.prompt_builder = MagicMock()

                # Configure the prompt builder to actually include the search results in the prompt
                async def build_prompt_with_content(query, search_results, **kwargs):
                    context = ""
                    for result in search_results.get("chunk_search_results", []):
                        context += f"{result.get('text', '')}\n\n"
                    return [
                        {"role": "system", "content": f"System prompt with hybrid context:\n\n{context}"},
                        {"role": "user", "content": query}
                    ]
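                # AsyncMock awaits an async callable passed as side_effect, so
                # build_prompt_with_content's return value is what callers receive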
                self.prompt_builder.build_prompt = AsyncMock(side_effect=build_prompt_with_content)

                # Configure LLM to return a valid response
                self.providers.llm.aget_completion = AsyncMock(return_value={
                    "choices": [{"message": {"content": "LLM generated response"}}]
                })

            async def generate_with_hybrid_search(self, query):
                # Perform hybrid search
                search_results = await perform_hybrid_search(query)
                # Build prompt with combined results
                messages = await self.prompt_builder.build_prompt(
                    query=query,
                    search_results=search_results
                )
                # Generate response
                response = await self.providers.llm.aget_completion(messages=messages)
                return response["choices"][0]["message"]["content"]

        # Create processor and generate response
        processor = HybridSearchRAGProcessor(mock_providers)
        query = "What did Aristotle say about ethics?"
        response = await processor.generate_with_hybrid_search(query)

        # Check that the LLM was called with the hybrid search results
        call_args = mock_providers.llm.aget_completion.call_args[1]
        messages = call_args["messages"]

        # Find the system message
        system_message = next((m for m in messages if m["role"] == "system"), None)

        # Verify both result types are in the context
        assert "Keyword match" in system_message["content"], "System message should include keyword matches"
        assert "Semantic match" in system_message["content"], "System message should include semantic matches"

        # Check the final response
        assert response == "LLM generated response", "Should return the mocked LLM response"

    @pytest.mark.asyncio
    async def test_reranking_strategy(self, mock_providers, mock_search_results):
        """Test reranking search results before including them in the RAG context."""
        # Define a reranker function
        def rerank_results(search_results, query):
            # This would use a model in a real implementation;
            # here we just simulate reranking with a simple heuristic.
            # Create a copy to avoid modifying the original
            reranked_results = {"chunk_search_results": []}
            # Apply a mock reranking logic
            for result in search_results["chunk_search_results"]:
                # Create a copy of the result
                new_result = result.copy()
                # Adjust the score based on whether it contains keywords from the query
                keywords = ["ethics", "aristotle", "philosophy"]
                score_adjustment = sum(0.1 for keyword in keywords
                                       if keyword.lower() in new_result["text"].lower())
                new_result["score"] = min(0.99, result.get("score", 0.5) + score_adjustment)
                new_result["reranked"] = True
                reranked_results["chunk_search_results"].append(new_result)
            # Sort by adjusted score
            reranked_results["chunk_search_results"].sort(
                key=lambda x: x.get("score", 0),
                reverse=True
            )
            return reranked_results

        # Mock RAG processor with reranking
        class RerankedRAGProcessor:
            def __init__(self, providers):
                self.providers = providers
                self.prompt_builder = MagicMock()
                self.prompt_builder.build_prompt = AsyncMock(
                    return_value=[
                        {"role": "system", "content": "System prompt with reranked context"},
                        {"role": "user", "content": "What did Aristotle say about ethics?"}
                    ]
                )

            async def generate_with_reranking(self, query, search_results):
                # Rerank the search results
                reranked_results = rerank_results(search_results, query)
                # Build the prompt with the reranked results
                messages = await self.prompt_builder.build_prompt(
                    query=query,
                    search_results=reranked_results
                )
                # Generate a response
                response = await self.providers.llm.aget_completion(messages=messages)
                return response["choices"][0]["message"]["content"]

        # Create the processor
        processor = RerankedRAGProcessor(mock_providers)

        # Generate a response with reranking
        query = "What did Aristotle say about ethics?"
        response = await processor.generate_with_reranking(query, mock_search_results)

        # Verify the LLM was called
        mock_providers.llm.aget_completion.assert_called_once()

        # Check the response
        assert response == "LLM generated response"