# retrieval_service.py
import json
import logging
import time
from typing import Optional
from uuid import UUID

from fastapi import HTTPException

from core import R2RStreamingRAGAgent
from core.base import (
    DocumentResponse,
    GenerationConfig,
    Message,
    R2RException,
    RunManager,
    SearchSettings,
    manage_run,
    to_async_generator,
)
from core.base.api.models import CombinedSearchResponse, RAGResponse, User
from core.telemetry.telemetry_decorator import telemetry_event
from shared.api.models.management.responses import MessageResponse

from ..abstractions import R2RAgents, R2RPipelines, R2RPipes, R2RProviders
from ..config import R2RConfig
from .base import Service

logger = logging.getLogger()
class RetrievalService(Service):
    """Service exposing the retrieval entry points (chunk search, document
    search, completion, embedding, RAG, and the conversational agent) on
    top of the configured providers, pipes, and pipelines.
    """

    def __init__(
        self,
        config: R2RConfig,
        providers: R2RProviders,
        pipes: R2RPipes,
        pipelines: R2RPipelines,
        agents: R2RAgents,
        run_manager: RunManager,
    ):
        # All wiring is delegated to the shared Service base class; this
        # subclass only adds behavior, no extra state.
        super().__init__(
            config,
            providers,
            pipes,
            pipelines,
            agents,
            run_manager,
        )
  43. @telemetry_event("Search")
  44. async def search( # TODO - rename to 'search_chunks'
  45. self,
  46. query: str,
  47. search_settings: SearchSettings = SearchSettings(),
  48. *args,
  49. **kwargs,
  50. ) -> CombinedSearchResponse:
  51. async with manage_run(self.run_manager) as run_id:
  52. t0 = time.time()
  53. if (
  54. search_settings.use_semantic_search
  55. and self.config.database.provider is None
  56. ):
  57. raise R2RException(
  58. status_code=400,
  59. message="Vector search is not enabled in the configuration.",
  60. )
  61. if (
  62. (
  63. search_settings.use_semantic_search
  64. and search_settings.use_fulltext_search
  65. )
  66. or search_settings.use_hybrid_search
  67. ) and not search_settings.hybrid_settings:
  68. raise R2RException(
  69. status_code=400,
  70. message="Hybrid search settings must be specified in the input configuration.",
  71. )
  72. # TODO - Remove these transforms once we have a better way to handle this
  73. for filter, value in search_settings.filters.items():
  74. if isinstance(value, UUID):
  75. search_settings.filters[filter] = str(value)
  76. merged_kwargs = {
  77. "input": to_async_generator([query]),
  78. "state": None,
  79. "search_settings": search_settings,
  80. "run_manager": self.run_manager,
  81. **kwargs,
  82. }
  83. results = await self.pipelines.search_pipeline.run(
  84. *args,
  85. **merged_kwargs,
  86. )
  87. t1 = time.time()
  88. latency = f"{t1 - t0:.2f}"
  89. return results.as_dict()
  90. @telemetry_event("SearchDocuments")
  91. async def search_documents(
  92. self,
  93. query: str,
  94. settings: SearchSettings,
  95. query_embedding: Optional[list[float]] = None,
  96. ) -> list[DocumentResponse]:
  97. return (
  98. await self.providers.database.documents_handler.search_documents(
  99. query_text=query,
  100. settings=settings,
  101. query_embedding=query_embedding,
  102. )
  103. )
  104. @telemetry_event("Completion")
  105. async def completion(
  106. self,
  107. messages: list[dict],
  108. generation_config: GenerationConfig,
  109. *args,
  110. **kwargs,
  111. ):
  112. return await self.providers.llm.aget_completion(
  113. [message.to_dict() for message in messages],
  114. generation_config,
  115. *args,
  116. **kwargs,
  117. )
  118. @telemetry_event("Embedding")
  119. async def embedding(
  120. self,
  121. text: str,
  122. ):
  123. return await self.providers.embedding.async_get_embedding(text=text)
  124. @telemetry_event("RAG")
  125. async def rag(
  126. self,
  127. query: str,
  128. rag_generation_config: GenerationConfig,
  129. search_settings: SearchSettings = SearchSettings(),
  130. *args,
  131. **kwargs,
  132. ) -> RAGResponse:
  133. async with manage_run(self.run_manager) as run_id:
  134. try:
  135. # TODO - Remove these transforms once we have a better way to handle this
  136. for (
  137. filter,
  138. value,
  139. ) in search_settings.filters.items():
  140. if isinstance(value, UUID):
  141. search_settings.filters[filter] = str(value)
  142. if rag_generation_config.stream:
  143. return await self.stream_rag_response(
  144. query,
  145. rag_generation_config,
  146. search_settings,
  147. *args,
  148. **kwargs,
  149. )
  150. merged_kwargs = {
  151. "input": to_async_generator([query]),
  152. "state": None,
  153. "search_settings": search_settings,
  154. "run_manager": self.run_manager,
  155. "rag_generation_config": rag_generation_config,
  156. **kwargs,
  157. }
  158. results = await self.pipelines.rag_pipeline.run(
  159. *args,
  160. **merged_kwargs,
  161. )
  162. if len(results) == 0:
  163. raise R2RException(
  164. status_code=404, message="No results found"
  165. )
  166. if len(results) > 1:
  167. logger.warning(
  168. f"Multiple results found for query: {query}"
  169. )
  170. # unpack the first result
  171. return results[0]
  172. except Exception as e:
  173. logger.error(f"Pipeline error: {str(e)}")
  174. if "NoneType" in str(e):
  175. raise HTTPException(
  176. status_code=502,
  177. detail="Remote server not reachable or returned an invalid response",
  178. ) from e
  179. raise HTTPException(
  180. status_code=500, detail="Internal Server Error"
  181. ) from e
  182. async def stream_rag_response(
  183. self,
  184. query,
  185. rag_generation_config,
  186. search_settings,
  187. *args,
  188. **kwargs,
  189. ):
  190. async def stream_response():
  191. async with manage_run(self.run_manager, "rag"):
  192. merged_kwargs = {
  193. "input": to_async_generator([query]),
  194. "state": None,
  195. "run_manager": self.run_manager,
  196. "search_settings": search_settings,
  197. "rag_generation_config": rag_generation_config,
  198. **kwargs,
  199. }
  200. async for (
  201. chunk
  202. ) in await self.pipelines.streaming_rag_pipeline.run(
  203. *args,
  204. **merged_kwargs,
  205. ):
  206. yield chunk
  207. return stream_response()
  208. @telemetry_event("Agent")
  209. async def agent(
  210. self,
  211. rag_generation_config: GenerationConfig,
  212. search_settings: SearchSettings = SearchSettings(),
  213. task_prompt_override: Optional[str] = None,
  214. include_title_if_available: Optional[bool] = False,
  215. conversation_id: Optional[UUID] = None,
  216. message: Optional[Message] = None,
  217. messages: Optional[list[Message]] = None,
  218. ):
  219. async with manage_run(self.run_manager) as run_id:
  220. try:
  221. if message and messages:
  222. raise R2RException(
  223. status_code=400,
  224. message="Only one of message or messages should be provided",
  225. )
  226. if not message and not messages:
  227. raise R2RException(
  228. status_code=400,
  229. message="Either message or messages should be provided",
  230. )
  231. # Ensure 'message' is a Message instance
  232. if message and not isinstance(message, Message):
  233. if isinstance(message, dict):
  234. message = Message.from_dict(message)
  235. else:
  236. raise R2RException(
  237. status_code=400,
  238. message="""
  239. Invalid message format. The expected format contains:
  240. role: MessageType | 'system' | 'user' | 'assistant' | 'function'
  241. content: Optional[str]
  242. name: Optional[str]
  243. function_call: Optional[dict[str, Any]]
  244. tool_calls: Optional[list[dict[str, Any]]]
  245. """,
  246. )
  247. # Ensure 'messages' is a list of Message instances
  248. if messages:
  249. processed_messages = []
  250. for message in messages:
  251. if isinstance(message, Message):
  252. processed_messages.append(message)
  253. elif hasattr(message, "dict"):
  254. processed_messages.append(
  255. Message.from_dict(message.dict())
  256. )
  257. elif isinstance(message, dict):
  258. processed_messages.append(
  259. Message.from_dict(message)
  260. )
  261. else:
  262. processed_messages.append(
  263. Message.from_dict(str(message))
  264. )
  265. messages = processed_messages
  266. else:
  267. messages = []
  268. # Transform UUID filters to strings
  269. for filter_key, value in search_settings.filters.items():
  270. if isinstance(value, UUID):
  271. search_settings.filters[filter_key] = str(value)
  272. ids = []
  273. if conversation_id: # Fetch the existing conversation
  274. try:
  275. conversation_messages = await self.providers.database.conversations_handler.get_conversation(
  276. conversation_id=conversation_id,
  277. )
  278. except Exception as e:
  279. logger.error(f"Error fetching conversation: {str(e)}")
  280. if conversation_messages is not None:
  281. messages_from_conversation: list[Message] = []
  282. for message_response in conversation_messages:
  283. if isinstance(message_response, MessageResponse):
  284. messages_from_conversation.append(
  285. message_response.message
  286. )
  287. ids.append(message_response.id)
  288. else:
  289. logger.warning(
  290. f"Unexpected type in conversation found: {type(message_response)}\n{message_response}"
  291. )
  292. messages = messages_from_conversation + messages
  293. else: # Create new conversation
  294. conversation_response = (
  295. await self.providers.database.conversations_handler.create_conversation()
  296. )
  297. conversation_id = conversation_response.id
  298. if message:
  299. messages.append(message)
  300. if not messages:
  301. raise R2RException(
  302. status_code=400,
  303. message="No messages to process",
  304. )
  305. current_message = messages[-1]
  306. # Save the new message to the conversation
  307. parent_id = ids[-1] if ids else None
  308. message_response = await self.providers.database.conversations_handler.add_message(
  309. conversation_id=conversation_id,
  310. content=current_message,
  311. parent_id=parent_id,
  312. )
  313. message_id = (
  314. message_response.id
  315. if message_response is not None
  316. else None
  317. )
  318. if rag_generation_config.stream:
  319. async def stream_response():
  320. async with manage_run(self.run_manager, "rag_agent"):
  321. agent = R2RStreamingRAGAgent(
  322. database_provider=self.providers.database,
  323. llm_provider=self.providers.llm,
  324. config=self.config.agent,
  325. search_pipeline=self.pipelines.search_pipeline,
  326. )
  327. async for chunk in agent.arun(
  328. messages=messages,
  329. system_instruction=task_prompt_override,
  330. search_settings=search_settings,
  331. rag_generation_config=rag_generation_config,
  332. include_title_if_available=include_title_if_available,
  333. ):
  334. yield chunk
  335. return stream_response()
  336. results = await self.agents.rag_agent.arun(
  337. messages=messages,
  338. system_instruction=task_prompt_override,
  339. search_settings=search_settings,
  340. rag_generation_config=rag_generation_config,
  341. include_title_if_available=include_title_if_available,
  342. )
  343. # Save the assistant's reply to the conversation
  344. if isinstance(results[-1], dict):
  345. assistant_message = Message(**results[-1])
  346. elif isinstance(results[-1], Message):
  347. assistant_message = results[-1]
  348. else:
  349. assistant_message = Message(
  350. role="assistant", content=str(results[-1])
  351. )
  352. await self.providers.database.conversations_handler.add_message(
  353. conversation_id=conversation_id,
  354. content=assistant_message,
  355. parent_id=message_id,
  356. )
  357. return {
  358. "messages": results,
  359. "conversation_id": str(
  360. conversation_id
  361. ), # Ensure it's a string
  362. }
  363. except Exception as e:
  364. logger.error(f"Error in agent response: {str(e)}")
  365. if "NoneType" in str(e):
  366. raise HTTPException(
  367. status_code=502,
  368. detail="Server not reachable or returned an invalid response",
  369. )
  370. raise HTTPException(
  371. status_code=500,
  372. detail=f"Internal Server Error - {str(e)}",
  373. )
  374. class RetrievalServiceAdapter:
  375. @staticmethod
  376. def _parse_user_data(user_data):
  377. if isinstance(user_data, str):
  378. try:
  379. user_data = json.loads(user_data)
  380. except json.JSONDecodeError:
  381. raise ValueError(f"Invalid user data format: {user_data}")
  382. return User.from_dict(user_data)
  383. @staticmethod
  384. def prepare_search_input(
  385. query: str,
  386. search_settings: SearchSettings,
  387. user: User,
  388. ) -> dict:
  389. return {
  390. "query": query,
  391. "search_settings": search_settings.to_dict(),
  392. "user": user.to_dict(),
  393. }
  394. @staticmethod
  395. def parse_search_input(data: dict):
  396. return {
  397. "query": data["query"],
  398. "search_settings": SearchSettings.from_dict(
  399. data["search_settings"]
  400. ),
  401. "user": RetrievalServiceAdapter._parse_user_data(data["user"]),
  402. }
  403. @staticmethod
  404. def prepare_rag_input(
  405. query: str,
  406. search_settings: SearchSettings,
  407. rag_generation_config: GenerationConfig,
  408. task_prompt_override: Optional[str],
  409. user: User,
  410. ) -> dict:
  411. return {
  412. "query": query,
  413. "search_settings": search_settings.to_dict(),
  414. "rag_generation_config": rag_generation_config.to_dict(),
  415. "task_prompt_override": task_prompt_override,
  416. "user": user.to_dict(),
  417. }
  418. @staticmethod
  419. def parse_rag_input(data: dict):
  420. return {
  421. "query": data["query"],
  422. "search_settings": SearchSettings.from_dict(
  423. data["search_settings"]
  424. ),
  425. "rag_generation_config": GenerationConfig.from_dict(
  426. data["rag_generation_config"]
  427. ),
  428. "task_prompt_override": data["task_prompt_override"],
  429. "user": RetrievalServiceAdapter._parse_user_data(data["user"]),
  430. }
  431. @staticmethod
  432. def prepare_agent_input(
  433. message: Message,
  434. search_settings: SearchSettings,
  435. rag_generation_config: GenerationConfig,
  436. task_prompt_override: Optional[str],
  437. include_title_if_available: bool,
  438. user: User,
  439. conversation_id: Optional[str] = None,
  440. ) -> dict:
  441. return {
  442. "message": message.to_dict(),
  443. "search_settings": search_settings.to_dict(),
  444. "rag_generation_config": rag_generation_config.to_dict(),
  445. "task_prompt_override": task_prompt_override,
  446. "include_title_if_available": include_title_if_available,
  447. "user": user.to_dict(),
  448. "conversation_id": conversation_id,
  449. }
  450. @staticmethod
  451. def parse_agent_input(data: dict):
  452. return {
  453. "message": Message.from_dict(data["message"]),
  454. "search_settings": SearchSettings.from_dict(
  455. data["search_settings"]
  456. ),
  457. "rag_generation_config": GenerationConfig.from_dict(
  458. data["rag_generation_config"]
  459. ),
  460. "task_prompt_override": data["task_prompt_override"],
  461. "include_title_if_available": data["include_title_if_available"],
  462. "user": RetrievalServiceAdapter._parse_user_data(data["user"]),
  463. "conversation_id": data.get("conversation_id"),
  464. }