# retrieval_service.py
  1. import json
  2. import logging
  3. import time
  4. from typing import Optional
  5. from uuid import UUID
  6. from fastapi import HTTPException
  7. from core import R2RStreamingRAGAgent
  8. from core.base import (
  9. DocumentResponse,
  10. EmbeddingPurpose,
  11. GenerationConfig,
  12. GraphSearchSettings,
  13. Message,
  14. R2RException,
  15. RunManager,
  16. SearchMode,
  17. SearchSettings,
  18. manage_run,
  19. to_async_generator,
  20. )
  21. from core.base.api.models import CombinedSearchResponse, RAGResponse, User
  22. from core.base.logger.base import RunType
  23. from core.telemetry.telemetry_decorator import telemetry_event
  24. from shared.api.models.management.responses import MessageResponse
  25. from ..abstractions import R2RAgents, R2RPipelines, R2RPipes, R2RProviders
  26. from ..config import R2RConfig
  27. from .base import Service
  28. logger = logging.getLogger()
  29. class RetrievalService(Service):
  30. def __init__(
  31. self,
  32. config: R2RConfig,
  33. providers: R2RProviders,
  34. pipes: R2RPipes,
  35. pipelines: R2RPipelines,
  36. agents: R2RAgents,
  37. run_manager: RunManager,
  38. ):
  39. super().__init__(
  40. config,
  41. providers,
  42. pipes,
  43. pipelines,
  44. agents,
  45. run_manager,
  46. )
  47. @telemetry_event("Search")
  48. async def search( # TODO - rename to 'search_chunks'
  49. self,
  50. query: str,
  51. search_settings: SearchSettings = SearchSettings(),
  52. *args,
  53. **kwargs,
  54. ) -> CombinedSearchResponse:
  55. async with manage_run(self.run_manager, RunType.RETRIEVAL) as run_id:
  56. t0 = time.time()
  57. if (
  58. search_settings.use_semantic_search
  59. and self.config.database.provider is None
  60. ):
  61. raise R2RException(
  62. status_code=400,
  63. message="Vector search is not enabled in the configuration.",
  64. )
  65. if (
  66. (
  67. search_settings.use_semantic_search
  68. and search_settings.use_fulltext_search
  69. )
  70. or search_settings.use_hybrid_search
  71. ) and not search_settings.hybrid_settings:
  72. raise R2RException(
  73. status_code=400,
  74. message="Hybrid search settings must be specified in the input configuration.",
  75. )
  76. # TODO - Remove these transforms once we have a better way to handle this
  77. for filter, value in search_settings.filters.items():
  78. if isinstance(value, UUID):
  79. search_settings.filters[filter] = str(value)
  80. merged_kwargs = {
  81. "input": to_async_generator([query]),
  82. "state": None,
  83. "search_settings": search_settings,
  84. "run_manager": self.run_manager,
  85. **kwargs,
  86. }
  87. results = await self.pipelines.search_pipeline.run(
  88. *args,
  89. **merged_kwargs,
  90. )
  91. t1 = time.time()
  92. latency = f"{t1 - t0:.2f}"
  93. return results.as_dict()
  94. @telemetry_event("SearchDocuments")
  95. async def search_documents(
  96. self,
  97. query: str,
  98. settings: SearchSettings,
  99. query_embedding: Optional[list[float]] = None,
  100. ) -> list[DocumentResponse]:
  101. return (
  102. await self.providers.database.documents_handler.search_documents(
  103. query_text=query,
  104. settings=settings,
  105. query_embedding=query_embedding,
  106. )
  107. )
  108. @telemetry_event("Completion")
  109. async def completion(
  110. self,
  111. messages: list[dict],
  112. generation_config: GenerationConfig,
  113. *args,
  114. **kwargs,
  115. ):
  116. return await self.providers.llm.aget_completion(
  117. [message.to_dict() for message in messages],
  118. generation_config,
  119. *args,
  120. **kwargs,
  121. )
  122. @telemetry_event("Embedding")
  123. async def embedding(
  124. self,
  125. text: str,
  126. ):
  127. return await self.providers.embedding.async_get_embedding(text=text)
  128. @telemetry_event("RAG")
  129. async def rag(
  130. self,
  131. query: str,
  132. rag_generation_config: GenerationConfig,
  133. search_settings: SearchSettings = SearchSettings(),
  134. *args,
  135. **kwargs,
  136. ) -> RAGResponse:
  137. async with manage_run(self.run_manager, RunType.RETRIEVAL) as run_id:
  138. try:
  139. # TODO - Remove these transforms once we have a better way to handle this
  140. for (
  141. filter,
  142. value,
  143. ) in search_settings.filters.items():
  144. if isinstance(value, UUID):
  145. search_settings.filters[filter] = str(value)
  146. if rag_generation_config.stream:
  147. return await self.stream_rag_response(
  148. query,
  149. rag_generation_config,
  150. search_settings,
  151. *args,
  152. **kwargs,
  153. )
  154. merged_kwargs = {
  155. "input": to_async_generator([query]),
  156. "state": None,
  157. "search_settings": search_settings,
  158. "run_manager": self.run_manager,
  159. "rag_generation_config": rag_generation_config,
  160. **kwargs,
  161. }
  162. results = await self.pipelines.rag_pipeline.run(
  163. *args,
  164. **merged_kwargs,
  165. )
  166. if len(results) == 0:
  167. raise R2RException(
  168. status_code=404, message="No results found"
  169. )
  170. if len(results) > 1:
  171. logger.warning(
  172. f"Multiple results found for query: {query}"
  173. )
  174. # unpack the first result
  175. return results[0]
  176. except Exception as e:
  177. logger.error(f"Pipeline error: {str(e)}")
  178. if "NoneType" in str(e):
  179. raise HTTPException(
  180. status_code=502,
  181. detail="Remote server not reachable or returned an invalid response",
  182. ) from e
  183. raise HTTPException(
  184. status_code=500, detail="Internal Server Error"
  185. ) from e
  186. async def stream_rag_response(
  187. self,
  188. query,
  189. rag_generation_config,
  190. search_settings,
  191. *args,
  192. **kwargs,
  193. ):
  194. async def stream_response():
  195. async with manage_run(self.run_manager, "rag"):
  196. merged_kwargs = {
  197. "input": to_async_generator([query]),
  198. "state": None,
  199. "run_manager": self.run_manager,
  200. "search_settings": search_settings,
  201. "rag_generation_config": rag_generation_config,
  202. **kwargs,
  203. }
  204. async for (
  205. chunk
  206. ) in await self.pipelines.streaming_rag_pipeline.run(
  207. *args,
  208. **merged_kwargs,
  209. ):
  210. yield chunk
  211. return stream_response()
  212. @telemetry_event("Agent")
  213. async def agent(
  214. self,
  215. rag_generation_config: GenerationConfig,
  216. search_settings: SearchSettings = SearchSettings(),
  217. task_prompt_override: Optional[str] = None,
  218. include_title_if_available: Optional[bool] = False,
  219. conversation_id: Optional[UUID] = None,
  220. message: Optional[Message] = None,
  221. messages: Optional[list[Message]] = None,
  222. ):
  223. async with manage_run(self.run_manager, RunType.RETRIEVAL) as run_id:
  224. try:
  225. if message and messages:
  226. raise R2RException(
  227. status_code=400,
  228. message="Only one of message or messages should be provided",
  229. )
  230. if not message and not messages:
  231. raise R2RException(
  232. status_code=400,
  233. message="Either message or messages should be provided",
  234. )
  235. # Ensure 'message' is a Message instance
  236. if message and not isinstance(message, Message):
  237. if isinstance(message, dict):
  238. message = Message.from_dict(message)
  239. else:
  240. raise R2RException(
  241. status_code=400,
  242. message="""
  243. Invalid message format. The expected format contains:
  244. role: MessageType | 'system' | 'user' | 'assistant' | 'function'
  245. content: Optional[str]
  246. name: Optional[str]
  247. function_call: Optional[dict[str, Any]]
  248. tool_calls: Optional[list[dict[str, Any]]]
  249. """,
  250. )
  251. # Ensure 'messages' is a list of Message instances
  252. if messages:
  253. processed_messages = []
  254. for message in messages:
  255. if isinstance(message, Message):
  256. processed_messages.append(message)
  257. elif hasattr(message, "dict"):
  258. processed_messages.append(
  259. Message.from_dict(message.dict())
  260. )
  261. elif isinstance(message, dict):
  262. processed_messages.append(
  263. Message.from_dict(message)
  264. )
  265. else:
  266. processed_messages.append(
  267. Message.from_dict(str(message))
  268. )
  269. messages = processed_messages
  270. else:
  271. messages = []
  272. # Transform UUID filters to strings
  273. for filter_key, value in search_settings.filters.items():
  274. if isinstance(value, UUID):
  275. search_settings.filters[filter_key] = str(value)
  276. ids = []
  277. if conversation_id: # Fetch the existing conversation
  278. try:
  279. conversation = await self.providers.database.conversations_handler.get_conversations_overview(
  280. offset=0,
  281. limit=1,
  282. conversation_ids=[conversation_id],
  283. )
  284. except Exception as e:
  285. logger.error(f"Error fetching conversation: {str(e)}")
  286. if conversation is not None:
  287. messages_from_conversation: list[Message] = []
  288. for message_response in conversation:
  289. if isinstance(message_response, MessageResponse):
  290. messages_from_conversation.append(
  291. message_response.message
  292. )
  293. ids.append(message_response.id)
  294. else:
  295. logger.warning(
  296. f"Unexpected type in conversation found: {type(message_response)}\n{message_response}"
  297. )
  298. messages = messages_from_conversation + messages
  299. else: # Create new conversation
  300. conversation_response = (
  301. await self.providers.database.conversations_handler.create_conversation()
  302. )
  303. conversation_id = conversation_response.id
  304. if message:
  305. messages.append(message)
  306. if not messages:
  307. raise R2RException(
  308. status_code=400,
  309. message="No messages to process",
  310. )
  311. current_message = messages[-1]
  312. # Save the new message to the conversation
  313. parent_id = ids[-1] if ids else None
  314. message_response = await self.providers.database.conversations_handler.add_message(
  315. conversation_id=conversation_id,
  316. content=current_message,
  317. parent_id=parent_id,
  318. )
  319. message_id = (
  320. message_response.id
  321. if message_response is not None
  322. else None
  323. )
  324. if rag_generation_config.stream:
  325. async def stream_response():
  326. async with manage_run(self.run_manager, "rag_agent"):
  327. agent = R2RStreamingRAGAgent(
  328. database_provider=self.providers.database,
  329. llm_provider=self.providers.llm,
  330. config=self.config.agent,
  331. search_pipeline=self.pipelines.search_pipeline,
  332. )
  333. async for chunk in agent.arun(
  334. messages=messages,
  335. system_instruction=task_prompt_override,
  336. search_settings=search_settings,
  337. rag_generation_config=rag_generation_config,
  338. include_title_if_available=include_title_if_available,
  339. ):
  340. yield chunk
  341. return stream_response()
  342. results = await self.agents.rag_agent.arun(
  343. messages=messages,
  344. system_instruction=task_prompt_override,
  345. search_settings=search_settings,
  346. rag_generation_config=rag_generation_config,
  347. include_title_if_available=include_title_if_available,
  348. )
  349. # Save the assistant's reply to the conversation
  350. if isinstance(results[-1], dict):
  351. assistant_message = Message(**results[-1])
  352. elif isinstance(results[-1], Message):
  353. assistant_message = results[-1]
  354. else:
  355. assistant_message = Message(
  356. role="assistant", content=str(results[-1])
  357. )
  358. await self.providers.database.conversations_handler.add_message(
  359. conversation_id=conversation_id,
  360. content=assistant_message,
  361. parent_id=message_id,
  362. )
  363. return {
  364. "messages": results,
  365. "conversation_id": str(
  366. conversation_id
  367. ), # Ensure it's a string
  368. }
  369. except Exception as e:
  370. logger.error(f"Error in agent response: {str(e)}")
  371. if "NoneType" in str(e):
  372. raise HTTPException(
  373. status_code=502,
  374. detail="Server not reachable or returned an invalid response",
  375. )
  376. raise HTTPException(
  377. status_code=500,
  378. detail=f"Internal Server Error - {str(e)}",
  379. )
  380. class RetrievalServiceAdapter:
  381. @staticmethod
  382. def _parse_user_data(user_data):
  383. if isinstance(user_data, str):
  384. try:
  385. user_data = json.loads(user_data)
  386. except json.JSONDecodeError:
  387. raise ValueError(f"Invalid user data format: {user_data}")
  388. return User.from_dict(user_data)
  389. @staticmethod
  390. def prepare_search_input(
  391. query: str,
  392. search_settings: SearchSettings,
  393. user: User,
  394. ) -> dict:
  395. return {
  396. "query": query,
  397. "search_settings": search_settings.to_dict(),
  398. "user": user.to_dict(),
  399. }
  400. @staticmethod
  401. def parse_search_input(data: dict):
  402. return {
  403. "query": data["query"],
  404. "search_settings": SearchSettings.from_dict(
  405. data["search_settings"]
  406. ),
  407. "user": RetrievalServiceAdapter._parse_user_data(data["user"]),
  408. }
  409. @staticmethod
  410. def prepare_rag_input(
  411. query: str,
  412. search_settings: SearchSettings,
  413. rag_generation_config: GenerationConfig,
  414. task_prompt_override: Optional[str],
  415. user: User,
  416. ) -> dict:
  417. return {
  418. "query": query,
  419. "search_settings": search_settings.to_dict(),
  420. "rag_generation_config": rag_generation_config.to_dict(),
  421. "task_prompt_override": task_prompt_override,
  422. "user": user.to_dict(),
  423. }
  424. @staticmethod
  425. def parse_rag_input(data: dict):
  426. return {
  427. "query": data["query"],
  428. "search_settings": SearchSettings.from_dict(
  429. data["search_settings"]
  430. ),
  431. "rag_generation_config": GenerationConfig.from_dict(
  432. data["rag_generation_config"]
  433. ),
  434. "task_prompt_override": data["task_prompt_override"],
  435. "user": RetrievalServiceAdapter._parse_user_data(data["user"]),
  436. }
  437. @staticmethod
  438. def prepare_agent_input(
  439. message: Message,
  440. search_settings: SearchSettings,
  441. rag_generation_config: GenerationConfig,
  442. task_prompt_override: Optional[str],
  443. include_title_if_available: bool,
  444. user: User,
  445. conversation_id: Optional[str] = None,
  446. ) -> dict:
  447. return {
  448. "message": message.to_dict(),
  449. "search_settings": search_settings.to_dict(),
  450. "rag_generation_config": rag_generation_config.to_dict(),
  451. "task_prompt_override": task_prompt_override,
  452. "include_title_if_available": include_title_if_available,
  453. "user": user.to_dict(),
  454. "conversation_id": conversation_id,
  455. }
  456. @staticmethod
  457. def parse_agent_input(data: dict):
  458. return {
  459. "message": Message.from_dict(data["message"]),
  460. "search_settings": SearchSettings.from_dict(
  461. data["search_settings"]
  462. ),
  463. "rag_generation_config": GenerationConfig.from_dict(
  464. data["rag_generation_config"]
  465. ),
  466. "task_prompt_override": data["task_prompt_override"],
  467. "include_title_if_available": data["include_title_if_available"],
  468. "user": RetrievalServiceAdapter._parse_user_data(data["user"]),
  469. "conversation_id": data.get("conversation_id"),
  470. }