# retrieval_service.py

import asyncio
import json
import logging
from copy import deepcopy
from datetime import datetime
from typing import Any, AsyncGenerator, Literal, Optional
from uuid import UUID

from fastapi import HTTPException

from core import (
    Citation,
    R2RRAGAgent,
    R2RStreamingRAGAgent,
    R2RStreamingResearchAgent,
    R2RXMLToolsRAGAgent,
    R2RXMLToolsResearchAgent,
    R2RXMLToolsStreamingRAGAgent,
    R2RXMLToolsStreamingResearchAgent,
)
from core.agent.research import R2RResearchAgent
from core.base import (
    AggregateSearchResult,
    ChunkSearchResult,
    DocumentResponse,
    GenerationConfig,
    GraphCommunityResult,
    GraphEntityResult,
    GraphRelationshipResult,
    GraphSearchResult,
    GraphSearchResultType,
    IngestionStatus,
    Message,
    R2RException,
    SearchSettings,
    WebSearchResult,
    format_search_results_for_llm,
)
from core.base.agent.tools.registry import ToolRegistry
from core.base.api.models import RAGResponse, User
from core.utils import (
    CitationTracker,
    SearchResultsCollector,
    SSEFormatter,
    dump_collector,
    dump_obj,
    extract_citations,
    find_new_citation_spans,
    num_tokens_from_messages,
)
from shared.api.models.management.responses import MessageResponse

from ..abstractions import R2RProviders
from ..config import R2RConfig
from .base import Service

logger = logging.getLogger()

class AgentFactory:
    """
    Factory class that creates appropriate agent instances based on mode,
    model type, and streaming preferences.
    """

    @staticmethod
    def create_agent(
        mode: Literal["rag", "research"],
        database_provider,
        llm_provider,
        config,  # : AgentConfig
        search_settings,  # : SearchSettings
        generation_config,  # : GenerationConfig
        app_config,  # : AppConfig
        knowledge_search_method,
        content_method,
        file_search_method,
        max_tool_context_length: int = 32_768,
        rag_tools: Optional[list[str]] = None,
        research_tools: Optional[list[str]] = None,
        tools: Optional[list[str]] = None,  # For backward compatibility
    ):
        """
        Creates and returns the appropriate agent based on provided parameters.

        Args:
            mode: Either "rag" or "research" to determine agent type
            database_provider: Provider for database operations
            llm_provider: Provider for LLM operations
            config: Agent configuration
            search_settings: Search settings for retrieval
            generation_config: Generation configuration with LLM parameters
            app_config: Application configuration
            knowledge_search_method: Method for knowledge search
            content_method: Method for content retrieval
            file_search_method: Method for file search
            max_tool_context_length: Maximum context length for tools
            rag_tools: Tools specifically for RAG mode
            research_tools: Tools specifically for Research mode
            tools: Deprecated backward compatibility parameter

        Returns:
            An appropriate agent instance
        """
        # Create a deep copy of the config to avoid modifying the original
        agent_config = deepcopy(config)
        tool_registry = ToolRegistry()

        # Handle tool specifications based on mode
        if mode == "rag":
            # For RAG mode, prioritize explicitly passed rag_tools, then tools, then config defaults
            if rag_tools:
                agent_config.rag_tools = rag_tools
            elif tools:  # Backward compatibility
                agent_config.rag_tools = tools
            # If neither was provided, the config's default rag_tools will be used
        elif mode == "research":
            # For Research mode, prioritize explicitly passed research_tools, then tools, then config defaults
            if research_tools:
                agent_config.research_tools = research_tools
            elif tools:  # Backward compatibility
                agent_config.research_tools = tools
            # If neither was provided, the config's default research_tools will be used

        # Determine if we need XML-based tools based on the model
        use_xml_format = False
        # if generation_config.model:
        #     model_str = generation_config.model.lower()
        #     use_xml_format = "deepseek" in model_str or "gemini" in model_str

        # Set streaming mode based on the generation config
        is_streaming = generation_config.stream

        # Create the appropriate agent based on all factors
        if mode == "rag":
            # RAG mode agents
            if is_streaming:
                if use_xml_format:
                    return R2RXMLToolsStreamingRAGAgent(
                        database_provider=database_provider,
                        llm_provider=llm_provider,
                        config=agent_config,
                        search_settings=search_settings,
                        rag_generation_config=generation_config,
                        max_tool_context_length=max_tool_context_length,
                        knowledge_search_method=knowledge_search_method,
                        content_method=content_method,
                        file_search_method=file_search_method,
                    )
                else:
                    return R2RStreamingRAGAgent(
                        database_provider=database_provider,
                        llm_provider=llm_provider,
                        config=agent_config,
                        search_settings=search_settings,
                        rag_generation_config=generation_config,
                        max_tool_context_length=max_tool_context_length,
                        knowledge_search_method=knowledge_search_method,
                        content_method=content_method,
                        file_search_method=file_search_method,
                        tool_registry=tool_registry,
                    )
            else:
                if use_xml_format:
                    return R2RXMLToolsRAGAgent(
                        database_provider=database_provider,
                        llm_provider=llm_provider,
                        config=agent_config,
                        search_settings=search_settings,
                        rag_generation_config=generation_config,
                        max_tool_context_length=max_tool_context_length,
                        knowledge_search_method=knowledge_search_method,
                        content_method=content_method,
                        file_search_method=file_search_method,
                        tool_registry=tool_registry,
                    )
                else:
                    return R2RRAGAgent(
                        database_provider=database_provider,
                        llm_provider=llm_provider,
                        config=agent_config,
                        search_settings=search_settings,
                        rag_generation_config=generation_config,
                        max_tool_context_length=max_tool_context_length,
                        knowledge_search_method=knowledge_search_method,
                        content_method=content_method,
                        file_search_method=file_search_method,
                        tool_registry=tool_registry,
                    )
        else:
            # Research mode agents
            if is_streaming:
                if use_xml_format:
                    return R2RXMLToolsStreamingResearchAgent(
                        app_config=app_config,
                        database_provider=database_provider,
                        llm_provider=llm_provider,
                        config=agent_config,
                        search_settings=search_settings,
                        rag_generation_config=generation_config,
                        max_tool_context_length=max_tool_context_length,
                        knowledge_search_method=knowledge_search_method,
                        content_method=content_method,
                        file_search_method=file_search_method,
                    )
                else:
                    return R2RStreamingResearchAgent(
                        app_config=app_config,
                        database_provider=database_provider,
                        llm_provider=llm_provider,
                        config=agent_config,
                        search_settings=search_settings,
                        rag_generation_config=generation_config,
                        max_tool_context_length=max_tool_context_length,
                        knowledge_search_method=knowledge_search_method,
                        content_method=content_method,
                        file_search_method=file_search_method,
                    )
            else:
                if use_xml_format:
                    return R2RXMLToolsResearchAgent(
                        app_config=app_config,
                        database_provider=database_provider,
                        llm_provider=llm_provider,
                        config=agent_config,
                        search_settings=search_settings,
                        rag_generation_config=generation_config,
                        max_tool_context_length=max_tool_context_length,
                        knowledge_search_method=knowledge_search_method,
                        content_method=content_method,
                        file_search_method=file_search_method,
                    )
                else:
                    return R2RResearchAgent(
                        app_config=app_config,
                        database_provider=database_provider,
                        llm_provider=llm_provider,
                        config=agent_config,
                        search_settings=search_settings,
                        rag_generation_config=generation_config,
                        max_tool_context_length=max_tool_context_length,
                        knowledge_search_method=knowledge_search_method,
                        content_method=content_method,
                        file_search_method=file_search_method,
                    )
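
# Illustrative sketch (not part of the module): how a caller might obtain a
# non-streaming RAG agent from the factory. The model name and the `service`
# variable are hypothetical placeholders; the provider and config attributes
# mirror how RetrievalService.agent() invokes the factory below.
#
#   agent = AgentFactory.create_agent(
#       mode="rag",
#       database_provider=providers.database,
#       llm_provider=providers.llm,
#       config=deepcopy(config.agent),
#       search_settings=SearchSettings(),
#       generation_config=GenerationConfig(model="openai/gpt-4o", stream=False),
#       app_config=config.app,
#       knowledge_search_method=service.search,
#       content_method=service.get_context,
#       file_search_method=service.search_documents,
#   )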

class RetrievalService(Service):
    def __init__(
        self,
        config: R2RConfig,
        providers: R2RProviders,
    ):
        super().__init__(
            config,
            providers,
        )

    async def search(
        self,
        query: str,
        search_settings: SearchSettings = SearchSettings(),
        *args,
        **kwargs,
    ) -> AggregateSearchResult:
        """
        Depending on search_settings.search_strategy, fan out
        to the basic, hyde, or rag_fusion method. Each returns
        an AggregateSearchResult that includes chunk + graph results.
        """
        strategy = search_settings.search_strategy.lower()

        if strategy == "hyde":
            return await self._hyde_search(query, search_settings)
        elif strategy == "rag_fusion":
            return await self._rag_fusion_search(query, search_settings)
        else:
            # 'vanilla', 'basic', or anything else...
            return await self._basic_search(query, search_settings)

    async def _basic_search(
        self, query: str, search_settings: SearchSettings
    ) -> AggregateSearchResult:
        """
        1) Possibly embed the query (if semantic or hybrid).
        2) Chunk search.
        3) Graph search.
        4) Combine into an AggregateSearchResult.
        """
        # -- 1) Possibly embed the query
        query_vector = None
        if (
            search_settings.use_semantic_search
            or search_settings.use_hybrid_search
        ):
            query_vector = (
                await self.providers.completion_embedding.async_get_embedding(
                    text=query
                )
            )

        # -- 2) Chunk search
        chunk_results = []
        if search_settings.chunk_settings.enabled:
            chunk_results = await self._vector_search_logic(
                query_text=query,
                search_settings=search_settings,
                precomputed_vector=query_vector,  # Pass in the vector we just computed (if any)
            )

        # -- 3) Graph search
        graph_results = []
        if search_settings.graph_settings.enabled:
            graph_results = await self._graph_search_logic(
                query_text=query,
                search_settings=search_settings,
                precomputed_vector=query_vector,  # same idea
            )

        # -- 4) Combine
        return AggregateSearchResult(
            chunk_search_results=chunk_results,
            graph_search_results=graph_results,
        )

    async def _rag_fusion_search(
        self, query: str, search_settings: SearchSettings
    ) -> AggregateSearchResult:
        """
        Implements 'RAG Fusion':
        1) Generate N sub-queries from the user query
        2) For each sub-query => do chunk & graph search
        3) Combine / fuse all retrieved results using Reciprocal Rank Fusion
        4) Return an AggregateSearchResult
        """
        # 1) Generate sub-queries from the user's original query.
        # Typically you want the original query to remain in the set as well,
        # so that we do not lose the exact user intent.
        sub_queries = [query]
        if search_settings.num_sub_queries > 1:
            # Generate (num_sub_queries - 1) rephrasings
            # (Or just generate exactly search_settings.num_sub_queries,
            # and remove the first if you prefer.)
            extra = await self._generate_similar_queries(
                query=query,
                num_sub_queries=search_settings.num_sub_queries - 1,
            )
            sub_queries.extend(extra)

        # 2) For each sub-query => do chunk + graph search.
        # We'll store them in a structure so we can fuse them.
        # chunk_results_list is a list of lists of ChunkSearchResult
        # graph_results_list is a list of lists of GraphSearchResult
        chunk_results_list = []
        graph_results_list = []
        for sq in sub_queries:
            # Recompute or reuse the embedding if desired
            # (You could do so, but not mandatory if you have a local approach)
            # chunk + graph search
            aggr = await self._basic_search(sq, search_settings)
            chunk_results_list.append(aggr.chunk_search_results)
            graph_results_list.append(aggr.graph_search_results)

        # 3) Fuse the chunk results and fuse the graph results.
        # We'll use a simple RRF approach: each sub-query's result list
        # is a ranking from best to worst.
        fused_chunk_results = self._reciprocal_rank_fusion_chunks(  # type: ignore
            chunk_results_list  # type: ignore
        )
        filtered_graph_results = [
            results for results in graph_results_list if results is not None
        ]
        fused_graph_results = self._reciprocal_rank_fusion_graphs(
            filtered_graph_results
        )

        # Optionally, after the RRF, you may want to do a final semantic re-rank
        # of the fused results by the user's original query. E.g.:
        if fused_chunk_results:
            fused_chunk_results = (
                await self.providers.completion_embedding.arerank(
                    query=query,
                    results=fused_chunk_results,
                    limit=search_settings.limit,
                )
            )

        # Sort or slice the graph results if needed:
        if fused_graph_results and search_settings.include_scores:
            fused_graph_results.sort(
                key=lambda g: g.score if g.score is not None else 0.0,
                reverse=True,
            )
            fused_graph_results = fused_graph_results[: search_settings.limit]

        # 4) Return the final AggregateSearchResult
        return AggregateSearchResult(
            chunk_search_results=fused_chunk_results,
            graph_search_results=fused_graph_results,
        )

    async def _generate_similar_queries(
        self, query: str, num_sub_queries: int = 2
    ) -> list[str]:
        """
        Use your LLM to produce 'similar' queries or rephrasings
        that might retrieve different but relevant documents.

        You can prompt your model with something like:
        "Given the user query, produce N alternative short queries that
        capture possible interpretations or expansions.
        Keep them relevant to the user's intent."
        """
        if num_sub_queries < 1:
            return []

        # In production, you'd fetch a prompt from your prompts DB.
        # Something like:
        prompt = f"""
        You are a helpful assistant. The user query is: "{query}"
        Generate {num_sub_queries} alternative search queries that capture
        slightly different phrasings or expansions while preserving the core meaning.
        Return each alternative on its own line.
        """

        # For a short generation, we can set minimal tokens
        gen_config = GenerationConfig(
            model=self.config.app.fast_llm,
            max_tokens=128,
            temperature=0.8,
            stream=False,
        )
        response = await self.providers.llm.aget_completion(
            messages=[{"role": "system", "content": prompt}],
            generation_config=gen_config,
        )
        raw_text = (
            response.choices[0].message.content.strip()
            if response.choices[0].message.content is not None
            else ""
        )
        # Suppose each line is a sub-query
        lines = [line.strip() for line in raw_text.split("\n") if line.strip()]
        return lines[:num_sub_queries]
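
    # Illustrative sketch (hypothetical output, not a real LLM response): for the
    # query "how does hybrid search work?" with num_sub_queries=2, the model might
    # return two lines such as
    #
    #   combining keyword and vector search
    #   hybrid retrieval with BM25 and embeddings
    #
    # which the line splitter above turns into a two-element list of sub-queries.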

    def _reciprocal_rank_fusion_chunks(
        self, list_of_rankings: list[list[ChunkSearchResult]], k: float = 60.0
    ) -> list[ChunkSearchResult]:
        """
        Simple RRF for chunk results.
        list_of_rankings is something like:
        [
            [chunkA, chunkB, chunkC],  # sub-query #1, in order
            [chunkC, chunkD],          # sub-query #2, in order
            ...
        ]
        We'll produce a dictionary mapping chunk.id -> aggregated_score,
        then sort descending.
        """
        if not list_of_rankings:
            return []

        # Build a map of chunk_id => final_rrf_score
        score_map: dict[str, float] = {}

        # We also need to store a reference to the chunk object
        # (the "first" or "best" instance), so we can reconstruct them later
        chunk_map: dict[str, Any] = {}

        for ranking_list in list_of_rankings:
            for rank, chunk_result in enumerate(ranking_list, start=1):
                if not chunk_result.id:
                    # fallback if no chunk_id is present
                    continue
                c_id = chunk_result.id
                # RRF scoring:
                # score = sum(1 / (k + rank)) over each sub-query ranking.
                # We'll accumulate it.
                existing_score = score_map.get(str(c_id), 0.0)
                new_score = existing_score + 1.0 / (k + rank)
                score_map[str(c_id)] = new_score

                # Keep a reference to the chunk; the map is keyed by the
                # stringified id, so check membership with the same key
                if str(c_id) not in chunk_map:
                    chunk_map[str(c_id)] = chunk_result

        # Now sort by final score
        fused_items = sorted(
            score_map.items(), key=lambda x: x[1], reverse=True
        )

        # Rebuild the final list of chunk results with the new 'score'
        fused_chunks = []
        for c_id, agg_score in fused_items:  # type: ignore
            # copy the chunk
            c = chunk_map[str(c_id)]
            # Optionally store the RRF score if you want
            c.score = agg_score
            fused_chunks.append(c)
        return fused_chunks
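
    # Worked example (illustrative only): with k = 60 and two sub-query rankings
    #   ranking 1: [chunkA, chunkB]
    #   ranking 2: [chunkB, chunkC]
    # the accumulated RRF scores are
    #   chunkB: 1/(60+2) + 1/(60+1) ≈ 0.0325
    #   chunkA: 1/(60+1)            ≈ 0.0164
    #   chunkC: 1/(60+2)            ≈ 0.0161
    # so the fused order is [chunkB, chunkA, chunkC].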

    def _reciprocal_rank_fusion_graphs(
        self, list_of_rankings: list[list[GraphSearchResult]], k: float = 60.0
    ) -> list[GraphSearchResult]:
        """
        Similar RRF logic, but for graph results.
        """
        if not list_of_rankings:
            return []

        score_map: dict[str, float] = {}
        graph_map = {}

        for ranking_list in list_of_rankings:
            for rank, g_result in enumerate(ranking_list, start=1):
                # We'll do a naive ID approach:
                # If your GraphSearchResult has a unique ID in g_result.content.id or so,
                # we can use that as a key.
                # If not, you might have to build a key from the content.
                g_id = None
                if hasattr(g_result.content, "id"):
                    g_id = str(g_result.content.id)
                else:
                    # fallback
                    g_id = f"graph_{hash(g_result.content.json())}"

                existing_score = score_map.get(g_id, 0.0)
                new_score = existing_score + 1.0 / (k + rank)
                score_map[g_id] = new_score

                if g_id not in graph_map:
                    graph_map[g_id] = g_result

        # Sort descending by aggregated RRF score
        fused_items = sorted(
            score_map.items(), key=lambda x: x[1], reverse=True
        )

        fused_graphs = []
        for g_id, agg_score in fused_items:
            g = graph_map[g_id]
            g.score = agg_score
            fused_graphs.append(g)
        return fused_graphs

    async def _hyde_search(
        self, query: str, search_settings: SearchSettings
    ) -> AggregateSearchResult:
        """
        1) Generate N hypothetical docs via LLM
        2) For each doc => embed => parallel chunk search & graph search
        3) Merge chunk results => optional re-rank => top K
        4) Merge graph results => (optionally re-rank or keep them distinct)
        """
        # 1) Generate hypothetical docs
        hyde_docs = await self._run_hyde_generation(
            query=query, num_sub_queries=search_settings.num_sub_queries
        )

        chunk_all = []
        graph_all = []

        # We'll gather the per-doc searches in parallel
        tasks = []
        for hypothetical_text in hyde_docs:
            tasks.append(
                asyncio.create_task(
                    self._fanout_chunk_and_graph_search(
                        user_text=query,  # The user's original query
                        alt_text=hypothetical_text,  # The hypothetical doc
                        search_settings=search_settings,
                    )
                )
            )

        # 2) Wait for them all
        results_list = await asyncio.gather(*tasks)
        # each item in results_list is a tuple: (chunks, graphs)

        # Flatten chunk + graph results
        for c_results, g_results in results_list:
            chunk_all.extend(c_results)
            graph_all.extend(g_results)

        # 3) Re-rank chunk results with the original query
        if chunk_all:
            chunk_all = await self.providers.completion_embedding.arerank(
                query=query,  # final user query
                results=chunk_all,
                limit=int(
                    search_settings.limit * search_settings.num_sub_queries
                ),
                # no limit on results - limit=search_settings.limit,
            )

        # 4) If needed, re-rank graph results or just slice top-K by score
        if search_settings.include_scores and graph_all:
            graph_all.sort(key=lambda g: g.score or 0.0, reverse=True)
            graph_all = (
                graph_all  # no limit on results - [: search_settings.limit]
            )

        return AggregateSearchResult(
            chunk_search_results=chunk_all,
            graph_search_results=graph_all,
        )
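
    # Illustrative sketch (not part of the service): HyDE is selected through the
    # search settings rather than called directly; the field values below are
    # examples and `retrieval_service` is a hypothetical instance.
    #
    #   settings = SearchSettings(
    #       search_strategy="hyde",
    #       num_sub_queries=3,
    #   )
    #   results = await retrieval_service.search("What is RAG fusion?", settings)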

    async def _fanout_chunk_and_graph_search(
        self,
        user_text: str,
        alt_text: str,
        search_settings: SearchSettings,
    ) -> tuple[list[ChunkSearchResult], list[GraphSearchResult]]:
        """
        1) embed alt_text (HyDE doc or sub-query, etc.)
        2) chunk search + graph search with that embedding
        """
        # Precompute the embedding of alt_text
        vec = await self.providers.completion_embedding.async_get_embedding(
            text=alt_text
        )

        # chunk search
        chunk_results = []
        if search_settings.chunk_settings.enabled:
            chunk_results = await self._vector_search_logic(
                query_text=user_text,  # used for text-based stuff & re-ranking
                search_settings=search_settings,
                precomputed_vector=vec,  # use the alt_text vector for semantic/hybrid
            )

        # graph search
        graph_results = []
        if search_settings.graph_settings.enabled:
            graph_results = await self._graph_search_logic(
                query_text=user_text,  # or alt_text if you prefer
                search_settings=search_settings,
                precomputed_vector=vec,
            )

        return (chunk_results, graph_results)

    async def _vector_search_logic(
        self,
        query_text: str,
        search_settings: SearchSettings,
        precomputed_vector: Optional[list[float]] = None,
    ) -> list[ChunkSearchResult]:
        """
        • If precomputed_vector is given, use it for semantic/hybrid search.
          Otherwise embed query_text ourselves.
        • Then do fulltext, semantic, or hybrid search.
        • Optionally re-rank and return results.
        """
        if not search_settings.chunk_settings.enabled:
            return []

        # 1) Possibly embed
        query_vector = precomputed_vector
        if query_vector is None and (
            search_settings.use_semantic_search
            or search_settings.use_hybrid_search
        ):
            query_vector = (
                await self.providers.completion_embedding.async_get_embedding(
                    text=query_text
                )
            )

        # 2) Choose which search to run
        if (
            search_settings.use_fulltext_search
            and search_settings.use_semantic_search
        ) or search_settings.use_hybrid_search:
            if query_vector is None:
                raise ValueError("Hybrid search requires a precomputed vector")
            raw_results = (
                await self.providers.database.chunks_handler.hybrid_search(
                    query_vector=query_vector,
                    query_text=query_text,
                    search_settings=search_settings,
                )
            )
        elif search_settings.use_fulltext_search:
            raw_results = (
                await self.providers.database.chunks_handler.full_text_search(
                    query_text=query_text,
                    search_settings=search_settings,
                )
            )
        elif search_settings.use_semantic_search:
            if query_vector is None:
                raise ValueError(
                    "Semantic search requires a precomputed vector"
                )
            raw_results = (
                await self.providers.database.chunks_handler.semantic_search(
                    query_vector=query_vector,
                    search_settings=search_settings,
                )
            )
        else:
            raise ValueError(
                "At least one of use_fulltext_search or use_semantic_search must be True"
            )

        # 3) Re-rank
        reranked = await self.providers.completion_embedding.arerank(
            query=query_text, results=raw_results, limit=search_settings.limit
        )

        # 4) Possibly augment text or metadata
        final_results = []
        for r in reranked:
            if "title" in r.metadata and search_settings.include_metadatas:
                title = r.metadata["title"]
                r.text = f"Document Title: {title}\n\nText: {r.text}"
            r.metadata["associated_query"] = query_text
            final_results.append(r)
        return final_results
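
    # Branch selection above, summarized (derived from the flags in step 2):
    #
    #   use_hybrid_search, or fulltext + semantic together  -> hybrid_search
    #   use_fulltext_search only                            -> full_text_search
    #   use_semantic_search only                            -> semantic_search
    #   none of the above                                   -> ValueError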

    async def _graph_search_logic(
        self,
        query_text: str,
        search_settings: SearchSettings,
        precomputed_vector: Optional[list[float]] = None,
    ) -> list[GraphSearchResult]:
        """
        Mirrors your previous GraphSearch approach:
        • if precomputed_vector is supplied, use that
        • otherwise embed query_text
        • search entities, relationships, communities
        • return results
        """
        results: list[GraphSearchResult] = []
        if not search_settings.graph_settings.enabled:
            return results

        # 1) Possibly embed
        query_embedding = precomputed_vector
        if query_embedding is None:
            query_embedding = (
                await self.providers.completion_embedding.async_get_embedding(
                    query_text
                )
            )

        base_limit = search_settings.limit
        graph_limits = search_settings.graph_settings.limits or {}

        # Entity search
        entity_limit = graph_limits.get("entities", base_limit)
        entity_cursor = self.providers.database.graphs_handler.graph_search(
            query_text,
            search_type="entities",
            limit=entity_limit,
            query_embedding=query_embedding,
            property_names=["name", "description", "id"],
            filters=search_settings.filters,
        )
        async for ent in entity_cursor:
            score = ent.get("similarity_score")
            metadata = ent.get("metadata", {})
            if isinstance(metadata, str):
                try:
                    metadata = json.loads(metadata)
                except Exception:
                    pass
            results.append(
                GraphSearchResult(
                    id=ent.get("id", None),
                    content=GraphEntityResult(
                        name=ent.get("name", ""),
                        description=ent.get("description", ""),
                        id=ent.get("id", None),
                    ),
                    result_type=GraphSearchResultType.ENTITY,
                    score=score if search_settings.include_scores else None,
                    metadata=(
                        {
                            **(metadata or {}),
                            "associated_query": query_text,
                        }
                        if search_settings.include_metadatas
                        else {}
                    ),
                )
            )

        # Relationship search
        rel_limit = graph_limits.get("relationships", base_limit)
        rel_cursor = self.providers.database.graphs_handler.graph_search(
            query_text,
            search_type="relationships",
            limit=rel_limit,
            query_embedding=query_embedding,
            property_names=[
                "id",
                "subject",
                "predicate",
                "object",
                "description",
                "subject_id",
                "object_id",
            ],
            filters=search_settings.filters,
        )
        async for rel in rel_cursor:
            score = rel.get("similarity_score")
            metadata = rel.get("metadata", {})
            if isinstance(metadata, str):
                try:
                    metadata = json.loads(metadata)
                except Exception:
                    pass
            results.append(
                GraphSearchResult(
                    # use the relationship's own id (not the last entity's)
                    id=rel.get("id", None),
                    content=GraphRelationshipResult(
                        id=rel.get("id", None),
                        subject=rel.get("subject", ""),
                        predicate=rel.get("predicate", ""),
                        object=rel.get("object", ""),
                        subject_id=rel.get("subject_id", None),
                        object_id=rel.get("object_id", None),
                        description=rel.get("description", ""),
                    ),
                    result_type=GraphSearchResultType.RELATIONSHIP,
                    score=score if search_settings.include_scores else None,
                    metadata=(
                        {
                            **(metadata or {}),
                            "associated_query": query_text,
                        }
                        if search_settings.include_metadatas
                        else {}
                    ),
                )
            )

        # Community search
        comm_limit = graph_limits.get("communities", base_limit)
        comm_cursor = self.providers.database.graphs_handler.graph_search(
            query_text,
            search_type="communities",
            limit=comm_limit,
            query_embedding=query_embedding,
            property_names=[
                "id",
                "name",
                "summary",
            ],
            filters=search_settings.filters,
        )
        async for comm in comm_cursor:
            score = comm.get("similarity_score")
            metadata = comm.get("metadata", {})
            if isinstance(metadata, str):
                try:
                    metadata = json.loads(metadata)
                except Exception:
                    pass
            results.append(
                GraphSearchResult(
                    # use the community's own id (not the last entity's)
                    id=comm.get("id", None),
                    content=GraphCommunityResult(
                        id=comm.get("id", None),
                        name=comm.get("name", ""),
                        summary=comm.get("summary", ""),
                    ),
                    result_type=GraphSearchResultType.COMMUNITY,
                    score=score if search_settings.include_scores else None,
                    metadata=(
                        {
                            **(metadata or {}),
                            "associated_query": query_text,
                        }
                        if search_settings.include_metadatas
                        else {}
                    ),
                )
            )

        return results

    async def _run_hyde_generation(
        self,
        query: str,
        num_sub_queries: int = 2,
    ) -> list[str]:
        """
        Calls the LLM with a 'HyDE' style prompt to produce multiple
        hypothetical documents/answers, one per line or separated by blank lines.
        """
        # Retrieve the prompt template from your database or config:
        # e.g. your "hyde" prompt has placeholders: {message}, {num_outputs}
        hyde_template = (
            await self.providers.database.prompts_handler.get_cached_prompt(
                prompt_name="hyde",
                inputs={"message": query, "num_outputs": num_sub_queries},
            )
        )

        # Now call the LLM with that as the system or user prompt:
        completion_config = GenerationConfig(
            model=self.config.app.fast_llm,  # or whichever short/cheap model
            max_tokens=512,
            temperature=0.7,
            stream=False,
        )
        response = await self.providers.llm.aget_completion(
            messages=[{"role": "system", "content": hyde_template}],
            generation_config=completion_config,
        )

        # Suppose the LLM returns something like:
        #
        #   "Doc1. Some made up text.\n\nDoc2. Another made up text.\n\n"
        #
        # So we split by double-newline or some pattern:
        raw_text = response.choices[0].message.content
        return [
            chunk.strip()
            for chunk in (raw_text or "").split("\n\n")
            if chunk.strip()
        ]

    async def search_documents(
        self,
        query: str,
        settings: SearchSettings,
        query_embedding: Optional[list[float]] = None,
    ) -> list[DocumentResponse]:
        if query_embedding is None:
            query_embedding = (
                await self.providers.completion_embedding.async_get_embedding(
                    query
                )
            )
        return (
            await self.providers.database.documents_handler.search_documents(
                query_text=query,
                settings=settings,
                query_embedding=query_embedding,
            )
        )

    async def completion(
        self,
        messages: list[dict],
        generation_config: GenerationConfig,
        *args,
        **kwargs,
    ):
        return await self.providers.llm.aget_completion(
            [message.to_dict() for message in messages],  # type: ignore
            generation_config,
            *args,
            **kwargs,
        )

    async def embedding(
        self,
        text: str,
    ):
        return await self.providers.completion_embedding.async_get_embedding(
            text=text
        )

    async def rag(
        self,
        query: str,
        rag_generation_config: GenerationConfig,
        search_settings: SearchSettings = SearchSettings(),
        system_prompt_name: str | None = None,
        task_prompt_name: str | None = None,
        include_web_search: bool = False,
        **kwargs,
    ) -> Any:
        """
        A single RAG method that can do EITHER a one-shot synchronous RAG or
        streaming SSE-based RAG, depending on rag_generation_config.stream.

        1) Perform aggregator search => context
        2) Build system+task prompts => messages
        3) If not streaming => normal LLM call => return RAGResponse
        4) If streaming => return an async generator of SSE lines
        """
        # 1) Possibly fix up any UUID filters in search_settings
        for f, val in list(search_settings.filters.items()):
            if isinstance(val, UUID):
                search_settings.filters[f] = str(val)

        try:
            # 2) Perform search => aggregated_results
            aggregated_results = await self.search(query, search_settings)

            # 3) Optionally add web search results if the flag is enabled
            if include_web_search:
                web_results = await self._perform_web_search(query)
                # Merge web search results with the existing aggregated results
                if web_results and web_results.web_search_results:
                    if not aggregated_results.web_search_results:
                        aggregated_results.web_search_results = (
                            web_results.web_search_results
                        )
                    else:
                        aggregated_results.web_search_results.extend(
                            web_results.web_search_results
                        )

            # 4) Build context from the aggregator
            collector = SearchResultsCollector()
            collector.add_aggregate_result(aggregated_results)
            context_str = format_search_results_for_llm(aggregated_results)

            # 5) Prepare system + task messages
            system_prompt_name = system_prompt_name or "system"
            task_prompt_name = task_prompt_name or "rag"
            task_prompt = kwargs.get("task_prompt")
            messages = await self.providers.database.prompts_handler.get_message_payload(
                system_prompt_name=system_prompt_name,
                task_prompt_name=task_prompt_name,
                task_inputs={"query": query, "context": context_str},
                task_prompt=task_prompt,
            )

            # 6) Check streaming vs. non-streaming
            if not rag_generation_config.stream:
                # ========== Non-Streaming Logic ==========
                response = await self.providers.llm.aget_completion(
                    messages=messages,
                    generation_config=rag_generation_config,
                )
                llm_text = response.choices[0].message.content

                # (a) Extract short-ID references from the final text
                raw_sids = extract_citations(llm_text or "")

                # (b) Possibly prune large content out of metadata
                metadata = response.dict()
                if "choices" in metadata and len(metadata["choices"]) > 0:
                    metadata["choices"][0]["message"].pop("content", None)

                # (c) Build the final RAGResponse
                rag_resp = RAGResponse(
                    generated_answer=llm_text or "",
                    search_results=aggregated_results,
                    citations=[
                        Citation(
                            id=f"{sid}",
                            object="citation",
                            payload=dump_obj(  # type: ignore
                                self._find_item_by_shortid(sid, collector)
                            ),
                        )
                        for sid in raw_sids
                    ],
                    metadata=metadata,
                    completion=llm_text or "",
                )
                return rag_resp
            else:
                # ========== Streaming SSE Logic ==========
                async def sse_generator() -> AsyncGenerator[str, None]:
                    # 1) Emit search results via SSEFormatter
                    async for line in SSEFormatter.yield_search_results_event(
                        aggregated_results
                    ):
                        yield line

                    # Initialize the citation tracker to manage citation state
                    citation_tracker = CitationTracker()

                    # Store citation payloads by ID for reuse
                    citation_payloads = {}

                    partial_text_buffer = ""

                    # Begin streaming from the LLM
                    msg_stream = self.providers.llm.aget_completion_stream(
                        messages=messages,
                        generation_config=rag_generation_config,
                    )
                    try:
                        async for chunk in msg_stream:
                            delta = chunk.choices[0].delta
                            finish_reason = chunk.choices[0].finish_reason

                            # check if delta has a `thinking` attribute
                            if hasattr(delta, "thinking") and delta.thinking:
                                # Emit an SSE "thinking" event
                                async for (
                                    line
                                ) in SSEFormatter.yield_thinking_event(
                                    delta.thinking
                                ):
                                    yield line

                            if delta.content:
                                # (b) Emit an SSE "message" event for this chunk of text
                                async for (
                                    line
                                ) in SSEFormatter.yield_message_event(
                                    delta.content
                                ):
                                    yield line

                                # Accumulate the new text
                                partial_text_buffer += delta.content

                                # (a) Extract citations from the updated buffer.
                                # For each *new* short ID, emit an SSE "citation" event.
                                # Find new citation spans in the accumulated text
                                new_citation_spans = find_new_citation_spans(
                                    partial_text_buffer, citation_tracker
                                )

                                # Process each new citation span
                                for cid, spans in new_citation_spans.items():
                                    for span in spans:
                                        # Check if this is the first time we've seen this citation ID
                                        is_new_citation = (
                                            citation_tracker.is_new_citation(
                                                cid
                                            )
                                        )

                                        # Get the payload if it's a new citation
                                        payload = None
                                        if is_new_citation:
                                            source_obj = (
                                                self._find_item_by_shortid(
                                                    cid, collector
                                                )
                                            )
                                            if source_obj:
                                                # Store the payload for reuse
                                                payload = dump_obj(source_obj)
                                                citation_payloads[cid] = (
                                                    payload
                                                )

                                        # Create the citation event payload
                                        citation_data = {
                                            "id": cid,
                                            "object": "citation",
                                            "is_new": is_new_citation,
                                            "span": {
                                                "start": span[0],
                                                "end": span[1],
                                            },
                                        }

                                        # Only include the full payload for new citations
                                        if is_new_citation and payload:
                                            citation_data["payload"] = payload

                                        # Emit the citation event
                                        async for (
                                            line
                                        ) in SSEFormatter.yield_citation_event(
                                            citation_data
                                        ):
                                            yield line

                            # If the LLM signals it's done
                            if finish_reason == "stop":
                                # Prepare consolidated citations for the final answer event
                                consolidated_citations = []
                                # Group citations by ID with all of their spans
                                for (
                                    cid,
                                    spans,
                                ) in citation_tracker.get_all_spans().items():
                                    if cid in citation_payloads:
                                        consolidated_citations.append(
                                            {
                                                "id": cid,
                                                "object": "citation",
                                                "spans": [
                                                    {
                                                        "start": s[0],
                                                        "end": s[1],
                                                    }
                                                    for s in spans
                                                ],
                                                "payload": citation_payloads[
                                                    cid
                                                ],
                                            }
                                        )

                                # (c) Emit the final answer + all collected citations
                                final_answer_evt = {
                                    "id": "msg_final",
                                    "object": "rag.final_answer",
                                    "generated_answer": partial_text_buffer,
                                    "citations": consolidated_citations,
                                }
                                async for (
                                    line
                                ) in SSEFormatter.yield_final_answer_event(
                                    final_answer_evt
                                ):
                                    yield line

                                # (d) Signal the end of the SSE stream
                                yield SSEFormatter.yield_done_event()
                                break
                    except Exception as e:
                        logger.error(f"Error streaming LLM in rag: {e}")
                        # Optionally yield an SSE "error" event or handle differently
                        raise

                return sse_generator()

        except Exception as e:
            logger.exception(f"Error in RAG pipeline: {e}")
            if "NoneType" in str(e):
                raise HTTPException(
                    status_code=502,
                    detail="Server not reachable or returned an invalid response",
                ) from e
            raise HTTPException(
                status_code=500,
                detail=f"Internal RAG Error - {str(e)}",
            ) from e
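
    # Event order produced by the streaming branch above (illustrative summary):
    #   1. search_results  - the aggregated chunk/graph (and web) results
    #   2. thinking        - only if the provider emits a `thinking` delta
    #   3. message         - one event per streamed text chunk
    #   4. citation        - whenever a new citation span appears in the buffer
    #   5. final_answer    - full text plus consolidated citation spans
    #   6. done            - end-of-stream marker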

    def _find_item_by_shortid(
        self, sid: str, collector: SearchResultsCollector
    ) -> Optional[tuple[str, Any, int]]:
        """
        Example helper that tries to match aggregator items by short ID,
        meaning result_obj.id starts with sid.
        """
        for source_type, result_obj in collector.get_all_results():
            # if the aggregator item has an 'id' attribute
            if getattr(result_obj, "id", None) is not None:
                full_id_str = str(result_obj.id)
                if full_id_str.startswith(sid):
                    if source_type == "chunk":
                        return (
                            result_obj.as_dict()
                        )  # (source_type, result_obj.as_dict())
                    else:
                        return result_obj  # (source_type, result_obj)
        return None
  1147. async def agent(
  1148. self,
  1149. rag_generation_config: GenerationConfig,
  1150. rag_tools: Optional[list[str]] = None,
  1151. tools: Optional[list[str]] = None, # backward compatibility
  1152. search_settings: SearchSettings = SearchSettings(),
  1153. task_prompt: Optional[str] = None,
  1154. include_title_if_available: Optional[bool] = False,
  1155. conversation_id: Optional[UUID] = None,
  1156. message: Optional[Message] = None,
  1157. messages: Optional[list[Message]] = None,
  1158. use_system_context: bool = False,
  1159. max_tool_context_length: int = 32_768,
  1160. research_tools: Optional[list[str]] = None,
  1161. research_generation_config: Optional[GenerationConfig] = None,
  1162. needs_initial_conversation_name: Optional[bool] = None,
  1163. mode: Optional[Literal["rag", "research"]] = "rag",
  1164. ):
  1165. """
  1166. Engage with an intelligent agent for information retrieval, analysis, and research.
  1167. Args:
  1168. rag_generation_config: Configuration for RAG mode generation
  1169. search_settings: Search configuration for retrieving context
  1170. task_prompt: Optional custom prompt override
  1171. include_title_if_available: Whether to include document titles
  1172. conversation_id: Optional conversation ID for continuity
  1173. message: Current message to process
  1174. messages: List of messages (deprecated)
  1175. use_system_context: Whether to use extended prompt
  1176. max_tool_context_length: Maximum context length for tools
  1177. rag_tools: List of tools for RAG mode
  1178. research_tools: List of tools for Research mode
  1179. research_generation_config: Configuration for Research mode generation
  1180. mode: Either "rag" or "research"
  1181. Returns:
  1182. Agent response with messages and conversation ID
  1183. """
  1184. try:
  1185. # Validate message inputs
  1186. if message and messages:
  1187. raise R2RException(
  1188. status_code=400,
  1189. message="Only one of message or messages should be provided",
  1190. )
  1191. if not message and not messages:
  1192. raise R2RException(
  1193. status_code=400,
  1194. message="Either message or messages should be provided",
  1195. )
  1196. # Ensure 'message' is a Message instance
  1197. if message and not isinstance(message, Message):
  1198. if isinstance(message, dict):
  1199. message = Message.from_dict(message)
  1200. else:
  1201. raise R2RException(
  1202. status_code=400,
  1203. message="""
  1204. Invalid message format. The expected format contains:
  1205. role: MessageType | 'system' | 'user' | 'assistant' | 'function'
  1206. content: Optional[str]
  1207. name: Optional[str]
  1208. function_call: Optional[dict[str, Any]]
  1209. tool_calls: Optional[list[dict[str, Any]]]
  1210. """,
  1211. )
  1212. # Ensure 'messages' is a list of Message instances
  1213. if messages:
  1214. processed_messages = []
  1215. for msg in messages:
  1216. if isinstance(msg, Message):
  1217. processed_messages.append(msg)
  1218. elif hasattr(msg, "dict"):
  1219. processed_messages.append(
  1220. Message.from_dict(msg.dict())
  1221. )
  1222. elif isinstance(msg, dict):
  1223. processed_messages.append(Message.from_dict(msg))
  1224. else:
  1225. processed_messages.append(Message.from_dict(str(msg)))
  1226. messages = processed_messages
  1227. else:
  1228. messages = []
  1229. # Validate and process mode-specific configurations
  1230. if mode == "rag" and research_tools:
  1231. logger.warning(
  1232. "research_tools provided but mode is 'rag'. These tools will be ignored."
  1233. )
  1234. research_tools = None
  1235. # Determine effective generation config based on mode
  1236. effective_generation_config = rag_generation_config
  1237. if mode == "research" and research_generation_config:
  1238. effective_generation_config = research_generation_config
  1239. # Set appropriate LLM model based on mode if not explicitly specified
  1240. if "model" not in effective_generation_config.model_fields_set:
  1241. if mode == "rag":
  1242. effective_generation_config.model = (
  1243. self.config.app.quality_llm
  1244. )
  1245. elif mode == "research":
  1246. effective_generation_config.model = (
  1247. self.config.app.planning_llm
  1248. )
  1249. # Transform UUID filters to strings
  1250. for filter_key, value in search_settings.filters.items():
  1251. if isinstance(value, UUID):
  1252. search_settings.filters[filter_key] = str(value)
  1253. # Process conversation data
  1254. ids = []
  1255. if conversation_id: # Fetch the existing conversation
  1256. try:
  1257. conversation_messages = await self.providers.database.conversations_handler.get_conversation(
  1258. conversation_id=conversation_id,
  1259. )
  1260. if needs_initial_conversation_name is None:
  1261. overview = await self.providers.database.conversations_handler.get_conversations_overview(
  1262. offset=0,
  1263. limit=1,
  1264. conversation_ids=[conversation_id],
  1265. )
  1266. if overview.get("total_entries", 0) > 0:
  1267. needs_initial_conversation_name = (
  1268. overview.get("results")[0].get("name") is None # type: ignore
  1269. )
  1270. except Exception as e:
  1271. logger.error(f"Error fetching conversation: {str(e)}")
  1272. if conversation_messages is not None:
  1273. messages_from_conversation: list[Message] = []
  1274. for message_response in conversation_messages:
  1275. if isinstance(message_response, MessageResponse):
  1276. messages_from_conversation.append(
  1277. message_response.message
  1278. )
  1279. ids.append(message_response.id)
  1280. else:
  1281. logger.warning(
  1282. f"Unexpected type in conversation found: {type(message_response)}\n{message_response}"
  1283. )
  1284. messages = messages_from_conversation + messages
  1285. else: # Create new conversation
  1286. conversation_response = await self.providers.database.conversations_handler.create_conversation()
  1287. conversation_id = conversation_response.id
  1288. needs_initial_conversation_name = True
  1289. if message:
  1290. messages.append(message)
  1291. if not messages:
  1292. raise R2RException(
  1293. status_code=400,
  1294. message="No messages to process",
  1295. )
  1296. current_message = messages[-1]
  1297. logger.debug(
  1298. f"Running the agent with conversation_id = {conversation_id} and message = {current_message}"
  1299. )
  1300. # Save the new message to the conversation
  1301. parent_id = ids[-1] if ids else None
  1302. message_response = await self.providers.database.conversations_handler.add_message(
  1303. conversation_id=conversation_id,
  1304. content=current_message,
  1305. parent_id=parent_id,
  1306. )
  1307. message_id = (
  1308. message_response.id if message_response is not None else None
  1309. )
  1310. # Extract filter information from search settings
  1311. filter_user_id, filter_collection_ids = (
  1312. self._parse_user_and_collection_filters(
  1313. search_settings.filters
  1314. )
  1315. )
  1316. # Validate system instruction configuration
  1317. if use_system_context and task_prompt:
  1318. raise R2RException(
  1319. status_code=400,
  1320. message="Both use_system_context and task_prompt cannot be True at the same time",
  1321. )
  1322. # Build the system instruction
  1323. if task_prompt:
  1324. system_instruction = task_prompt
  1325. else:
  1326. system_instruction = (
  1327. await self._build_aware_system_instruction(
  1328. max_tool_context_length=max_tool_context_length,
  1329. filter_user_id=filter_user_id,
  1330. filter_collection_ids=filter_collection_ids,
  1331. model=effective_generation_config.model,
  1332. use_system_context=use_system_context,
  1333. mode=mode,
  1334. )
  1335. )
            # Configure agent with appropriate tools
            agent_config = deepcopy(self.config.agent)
            if mode == "rag":
                # Use provided RAG tools or default from config
                agent_config.rag_tools = (
                    rag_tools or tools or self.config.agent.rag_tools
                )
            else:  # research mode
                # Use provided Research tools or default from config
                agent_config.research_tools = (
                    research_tools or tools or self.config.agent.research_tools
                )

            # Create the agent using our factory
            mode = mode or "rag"

            for msg in messages:
                if msg.content is None:
                    msg.content = ""

            agent = AgentFactory.create_agent(
                mode=mode,
                database_provider=self.providers.database,
                llm_provider=self.providers.llm,
                config=agent_config,
                search_settings=search_settings,
                generation_config=effective_generation_config,
                app_config=self.config.app,
                knowledge_search_method=self.search,
                content_method=self.get_context,
                file_search_method=self.search_documents,
                max_tool_context_length=max_tool_context_length,
                rag_tools=rag_tools,
                research_tools=research_tools,
                tools=tools,  # Backward compatibility
            )
            # Handle streaming vs. non-streaming response
            if effective_generation_config.stream:

                async def stream_response():
                    try:
                        async for chunk in agent.arun(
                            messages=messages,
                            system_instruction=system_instruction,
                            include_title_if_available=include_title_if_available,
                        ):
                            yield chunk
                    except Exception as e:
                        logger.error(f"Error streaming agent output: {e}")
                        raise e
                    finally:
                        # Persist conversation data
                        msgs = [
                            msg.to_dict()
                            for msg in agent.conversation.messages
                        ]
                        input_tokens = num_tokens_from_messages(msgs[:-1])
                        output_tokens = num_tokens_from_messages([msgs[-1]])
                        await self.providers.database.conversations_handler.add_message(
                            conversation_id=conversation_id,
                            content=agent.conversation.messages[-1],
                            parent_id=message_id,
                            metadata={
                                "input_tokens": input_tokens,
                                "output_tokens": output_tokens,
                            },
                        )
                        # Generate conversation name if needed
                        if needs_initial_conversation_name:
                            try:
                                prompt = f"Generate a succinct name (3-6 words) for this conversation, given the first input message here = {str(message.to_dict())}"
                                conversation_name = (
                                    (
                                        await self.providers.llm.aget_completion(
                                            [
                                                {
                                                    "role": "system",
                                                    "content": prompt,
                                                }
                                            ],
                                            GenerationConfig(
                                                model=self.config.app.fast_llm
                                            ),
                                        )
                                    )
                                    .choices[0]
                                    .message.content
                                )
                                await self.providers.database.conversations_handler.update_conversation(
                                    conversation_id=conversation_id,
                                    name=conversation_name,
                                )
                            except Exception as e:
                                logger.error(
                                    f"Error generating conversation name: {e}"
                                )

                return stream_response()
            else:
                for idx, msg in enumerate(messages):
                    if msg.content is None:
                        # Normalize missing message content to an empty string
                        messages[idx].content = ""
                # Non-streaming path
                results = await agent.arun(
                    messages=messages,
                    system_instruction=system_instruction,
                    include_title_if_available=include_title_if_available,
                )

                # Process the agent results
                if isinstance(results[-1], dict):
                    if results[-1].get("content") is None:
                        results[-1]["content"] = ""
                    assistant_message = Message(**results[-1])
                elif isinstance(results[-1], Message):
                    assistant_message = results[-1]
                    if assistant_message.content is None:
                        assistant_message.content = ""
                else:
                    assistant_message = Message(
                        role="assistant", content=str(results[-1])
                    )

                # Get search results collector for citations
                if hasattr(agent, "search_results_collector"):
                    collector = agent.search_results_collector
                else:
                    collector = SearchResultsCollector()

                # Extract content from the message
                structured_content = assistant_message.structured_content
                structured_content = (
                    structured_content[-1].get("text")
                    if structured_content
                    else None
                )
                raw_text = (
                    assistant_message.content or structured_content or ""
                )

                # Process citations
                short_ids = extract_citations(raw_text or "")
                final_citations = []
                for sid in short_ids:
                    obj = collector.find_by_short_id(sid)
                    final_citations.append(
                        {
                            "id": sid,
                            "object": "citation",
                            "payload": dump_obj(obj) if obj else None,
                        }
                    )

                # Persist in conversation DB
                await (
                    self.providers.database.conversations_handler.add_message(
                        conversation_id=conversation_id,
                        content=assistant_message,
                        parent_id=message_id,
                        metadata={
                            "citations": final_citations,
                            "aggregated_search_result": json.dumps(
                                dump_collector(collector)
                            ),
                        },
                    )
                )
                # Generate conversation name if needed
                if needs_initial_conversation_name:
                    conversation_name = None
                    try:
                        prompt = f"Generate a succinct name (3-6 words) for this conversation, given the first input message here = {str(message.to_dict() if message else {})}"
                        conversation_name = (
                            (
                                await self.providers.llm.aget_completion(
                                    [{"role": "system", "content": prompt}],
                                    GenerationConfig(
                                        model=self.config.app.fast_llm
                                    ),
                                )
                            )
                            .choices[0]
                            .message.content
                        )
                    except Exception as e:
                        logger.error(
                            f"Error generating conversation name: {e}"
                        )
                    finally:
                        await self.providers.database.conversations_handler.update_conversation(
                            conversation_id=conversation_id,
                            name=conversation_name or "",
                        )
                tool_calls = []
                if hasattr(agent, "tool_calls"):
                    if agent.tool_calls is not None:
                        tool_calls = agent.tool_calls
                    else:
                        logger.warning(
                            "agent.tool_calls is None, using empty list instead"
                        )

                # Return the final response
                return {
                    "messages": [
                        Message(
                            role="assistant",
                            content=assistant_message.content
                            or structured_content
                            or "",
                            metadata={
                                "citations": final_citations,
                                "tool_calls": tool_calls,
                                "aggregated_search_result": json.dumps(
                                    dump_collector(collector)
                                ),
                            },
                        )
                    ],
                    "conversation_id": str(conversation_id),
                }
        except Exception as e:
            logger.error(f"Error in agent response: {str(e)}")
            if "NoneType" in str(e):
                raise HTTPException(
                    status_code=502,
                    detail="Server not reachable or returned an invalid response",
                ) from e
            raise HTTPException(
                status_code=500,
                detail=f"Internal Server Error - {str(e)}",
            ) from e

    async def get_context(
        self,
        filters: dict[str, Any],
        options: dict[str, Any],
    ) -> list[dict[str, Any]]:
        """
        Return an ordered list of documents (with minimal overview fields),
        plus all associated chunks in ascending chunk order.

        Only the filters owner_id, collection_ids, and document_id are
        supported. If any other filter or operator is passed in, we raise
        an error.

        Args:
            filters: A dictionary describing the allowed filters
                (owner_id, collection_ids, document_id).
            options: A dictionary with extra options, e.g.
                include_summary_embedding or any custom flags for
                additional logic.

        Returns:
            A list of dicts, where each dict has:
            {
                "document": <DocumentResponse>,
                "chunks": [ <chunk0>, <chunk1>, ... ]
            }
        """
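        # Illustrative call only (hypothetical document UUID), e.g. when an
        # agent content tool needs the full text of a single document:
        #
        #   await self.get_context(
        #       filters={"document_id": {"$eq": "9fbe403b-c11c-5aae-8ade-ef22980c3ad1"}},
        #       options={"include_summary_embedding": False},
        #   )
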
        # Fetch matching documents
        matching_docs = await self.providers.database.documents_handler.get_documents_overview(
            offset=0,
            limit=-1,
            filters=filters,
            include_summary_embedding=options.get(
                "include_summary_embedding", False
            ),
        )

        if not matching_docs["results"]:
            return []

        # For each document, fetch associated chunks in ascending chunk order
        results = []
        for doc_response in matching_docs["results"]:
            doc_id = doc_response.id
            chunk_data = await self.providers.database.chunks_handler.list_document_chunks(
                document_id=doc_id,
                offset=0,
                limit=-1,  # get all chunks
                include_vectors=False,
            )
            chunks = chunk_data["results"]  # already sorted by chunk_order
            doc_response.chunks = chunks

            # Build a returned structure that includes doc + chunks
            results.append(doc_response.model_dump())

        return results

    def _parse_user_and_collection_filters(
        self,
        filters: dict[str, Any],
    ):
        ### TODO - Come up with smarter way to extract owner / collection ids for non-admin
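        # Shape this helper actually indexes into (placeholder UUIDs): an "$and"
        # whose first element holds an "$or" list with the owner filter at index 0
        # and the collection filter at index 1, or that "$or" list at the top level:
        #
        #   {"$and": [{"$or": [
        #       {"owner_id": {"$eq": "uuid-string-here"}},
        #       {"collection_ids": {"$overlap": ["uuid-of-some-collection"]}},
        #   ]}]}
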
        filter_starts_with_and = filters.get("$and")
        filter_starts_with_or = filters.get("$or")
        if filter_starts_with_and:
            try:
                filter_starts_with_and_then_or = filter_starts_with_and[0][
                    "$or"
                ]
                user_id = filter_starts_with_and_then_or[0]["owner_id"]["$eq"]
                collection_ids = [
                    str(ele)
                    for ele in filter_starts_with_and_then_or[1][
                        "collection_ids"
                    ]["$overlap"]
                ]
                return user_id, [str(ele) for ele in collection_ids]
            except Exception as e:
                logger.error(
                    f"Error: {e}.\n\n While"
                    + """ parsing filters: expected format {'$or': [{'owner_id': {'$eq': 'uuid-string-here'}, 'collection_ids': {'$overlap': ['uuid-of-some-collection']}}]}, if you are a superuser then this error can be ignored."""
                )
                return None, []
        elif filter_starts_with_or:
            try:
                user_id = str(filter_starts_with_or[0]["owner_id"]["$eq"])
                collection_ids = [
                    str(ele)
                    for ele in filter_starts_with_or[1]["collection_ids"][
                        "$overlap"
                    ]
                ]
                return user_id, [str(ele) for ele in collection_ids]
            except Exception as e:
                logger.error(
                    """Error parsing filters: expected format {'$or': [{'owner_id': {'$eq': 'uuid-string-here'}, 'collection_ids': {'$overlap': ['uuid-of-some-collection']}}]}, if you are a superuser then this error can be ignored."""
                    f"\n Instead, got: {filters}.\n\n Error: {e}"
                )
                return None, []
        else:
            # Admin user
            return None, []

    async def _build_documents_context(
        self,
        filter_user_id: Optional[UUID] = None,
        max_summary_length: int = 128,
        limit: int = 25,
        reverse_order: bool = True,
    ) -> str:
        """
        Fetches documents matching the given filters and returns a formatted
        string enumerating them.
        """
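        # Each document becomes one line of prompt context; a line might look
        # roughly like this (illustrative values only):
        #
        #   [1] Title: quarterly_report.pdf, Summary: Revenue grew 12%..., Total Tokens: 4096, ID: 9fbe403b-...
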
        # We only want up to `limit` documents for brevity
        docs_data = await self.providers.database.documents_handler.get_documents_overview(
            offset=0,
            limit=limit,
            filter_user_ids=[filter_user_id] if filter_user_id else None,
            include_summary_embedding=False,
            sort_order="DESC" if reverse_order else "ASC",
        )

        found_max = False
        if len(docs_data["results"]) == limit:
            found_max = True

        docs = docs_data["results"]
        if not docs:
            return "No documents found."

        lines = []
        for i, doc in enumerate(docs, start=1):
            if (
                not doc.summary
                or doc.ingestion_status != IngestionStatus.SUCCESS
            ):
                lines.append(
                    f"[{i}] Title: {doc.title}, Summary: (Summary not available), Status: {doc.ingestion_status}, ID: {doc.id}"
                )
                continue

            # Build a line referencing the doc
            title = doc.title or "(Untitled Document)"
            summary = doc.summary[0:max_summary_length] + (
                "..." if len(doc.summary) > max_summary_length else ""
            )
            lines.append(
                f"[{i}] Title: {title}, Summary: {summary}, Total Tokens: {doc.total_tokens}, ID: {doc.id}"
            )

        if found_max:
            lines.append(
                f"Note: Displaying only the first {limit} documents. Use a filter to narrow down the search if more documents are required."
            )

        return "\n".join(lines)

    async def _build_aware_system_instruction(
        self,
        max_tool_context_length: int = 10_000,
        filter_user_id: Optional[UUID] = None,
        filter_collection_ids: Optional[list[UUID]] = None,
        model: Optional[str] = None,
        use_system_context: bool = False,
        mode: Optional[str] = "rag",
    ) -> str:
        """
        High-level method that:
        1) builds the documents context
        2) builds the collections context
        3) loads the new `dynamic_reasoning_rag_agent` prompt
        """
        date_str = str(datetime.now().strftime("%m/%d/%Y"))

        # "dynamic_rag_agent" // "static_rag_agent"
        if mode == "rag":
            prompt_name = (
                self.config.agent.rag_agent_dynamic_prompt
                if use_system_context
                else self.config.agent.rag_rag_agent_static_prompt
            )
        else:
            prompt_name = "static_research_agent"
            return await self.providers.database.prompts_handler.get_cached_prompt(
                # We use custom tooling and a custom agent to handle gemini models
                prompt_name,
                inputs={
                    "date": date_str,
                },
            )

        if model is not None and ("deepseek" in model):
            prompt_name = f"{prompt_name}_xml_tooling"

        if use_system_context:
            doc_context_str = await self._build_documents_context(
                filter_user_id=filter_user_id,
            )
            logger.debug(f"Loading prompt {prompt_name}")

            # Now fetch the prompt from the database prompts handler
            # This relies on your "rag_agent_extended" existing with
            # placeholders: date, document_context
            system_prompt = await self.providers.database.prompts_handler.get_cached_prompt(
                # We use custom tooling and a custom agent to handle gemini models
                prompt_name,
                inputs={
                    "date": date_str,
                    "max_tool_context_length": max_tool_context_length,
                    "document_context": doc_context_str,
                },
            )
        else:
            system_prompt = await self.providers.database.prompts_handler.get_cached_prompt(
                prompt_name,
                inputs={
                    "date": date_str,
                },
            )

        logger.debug(f"Running agent with system prompt = {system_prompt}")
        return system_prompt

    async def _perform_web_search(
        self,
        query: str,
        search_settings: SearchSettings = SearchSettings(),
    ) -> AggregateSearchResult:
        """
        Perform a web search using an external search engine API (Serper).

        Args:
            query: The search query string
            search_settings: Optional search settings to customize the search

        Returns:
            AggregateSearchResult containing web search results
        """
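        # Rough usage sketch (assumes a Serper API key is configured for
        # SerperClient); the field access mirrors the logging further down:
        #
        #   agg = await self._perform_web_search("current R2R release notes")
        #   organic = agg.web_search_results[0].organic_results
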
        try:
            # Import the Serper client here to avoid circular imports
            from core.utils.serper import SerperClient

            # Initialize the Serper client
            serper_client = SerperClient()

            # Perform the raw search using Serper API
            raw_results = serper_client.get_raw(query)

            # Process the raw results into a WebSearchResult object
            web_response = WebSearchResult.from_serper_results(raw_results)

            # Create an AggregateSearchResult with the web search results
            # FIXME: Need to understand why we would have had this referencing only web_response.organic_results
            agg_result = AggregateSearchResult(
                web_search_results=[web_response]
            )

            # Log the search for monitoring purposes
            logger.debug(f"Web search completed for query: {query}")
            logger.debug(
                f"Found {len(web_response.organic_results)} web results"
            )

            return agg_result
        except Exception as e:
            logger.error(f"Error performing web search: {str(e)}")
            # Return empty results rather than failing completely
            return AggregateSearchResult(
                chunk_search_results=None,
                graph_search_results=None,
                web_search_results=[],
            )


class RetrievalServiceAdapter:
    @staticmethod
    def _parse_user_data(user_data):
        if isinstance(user_data, str):
            try:
                user_data = json.loads(user_data)
            except json.JSONDecodeError as e:
                raise ValueError(
                    f"Invalid user data format: {user_data}"
                ) from e
        return User.from_dict(user_data)

    @staticmethod
    def prepare_search_input(
        query: str,
        search_settings: SearchSettings,
        user: User,
    ) -> dict:
        return {
            "query": query,
            "search_settings": search_settings.to_dict(),
            "user": user.to_dict(),
        }

    @staticmethod
    def parse_search_input(data: dict):
        return {
            "query": data["query"],
            "search_settings": SearchSettings.from_dict(
                data["search_settings"]
            ),
            "user": RetrievalServiceAdapter._parse_user_data(data["user"]),
        }
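
    # Round-trip sketch (hypothetical inputs): each prepare_* / parse_* pair is
    # symmetric, so a payload serialized on one side can be rebuilt on the other:
    #
    #   payload = RetrievalServiceAdapter.prepare_search_input(query, search_settings, user)
    #   kwargs = RetrievalServiceAdapter.parse_search_input(payload)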

    @staticmethod
    def prepare_rag_input(
        query: str,
        search_settings: SearchSettings,
        rag_generation_config: GenerationConfig,
        task_prompt: Optional[str],
        include_web_search: bool,
        user: User,
    ) -> dict:
        return {
            "query": query,
            "search_settings": search_settings.to_dict(),
            "rag_generation_config": rag_generation_config.to_dict(),
            "task_prompt": task_prompt,
            "include_web_search": include_web_search,
            "user": user.to_dict(),
        }

    @staticmethod
    def parse_rag_input(data: dict):
        return {
            "query": data["query"],
            "search_settings": SearchSettings.from_dict(
                data["search_settings"]
            ),
            "rag_generation_config": GenerationConfig.from_dict(
                data["rag_generation_config"]
            ),
            "task_prompt": data["task_prompt"],
            "include_web_search": data["include_web_search"],
            "user": RetrievalServiceAdapter._parse_user_data(data["user"]),
        }

    @staticmethod
    def prepare_agent_input(
        message: Message,
        search_settings: SearchSettings,
        rag_generation_config: GenerationConfig,
        task_prompt: Optional[str],
        include_title_if_available: bool,
        user: User,
        conversation_id: Optional[str] = None,
    ) -> dict:
        return {
            "message": message.to_dict(),
            "search_settings": search_settings.to_dict(),
            "rag_generation_config": rag_generation_config.to_dict(),
            "task_prompt": task_prompt,
            "include_title_if_available": include_title_if_available,
            "user": user.to_dict(),
            "conversation_id": conversation_id,
        }

    @staticmethod
    def parse_agent_input(data: dict):
        return {
            "message": Message.from_dict(data["message"]),
            "search_settings": SearchSettings.from_dict(
                data["search_settings"]
            ),
            "rag_generation_config": GenerationConfig.from_dict(
                data["rag_generation_config"]
            ),
            "task_prompt": data["task_prompt"],
            "include_title_if_available": data["include_title_if_available"],
            "user": RetrievalServiceAdapter._parse_user_data(data["user"]),
            "conversation_id": data.get("conversation_id"),
        }