# retrieval.py

from typing import Any, AsyncGenerator, Optional
from uuid import UUID

from shared.api.models import (
    WrappedAgentResponse,
    WrappedEmbeddingResponse,
    WrappedLLMChatCompletion,
    WrappedRAGResponse,
    WrappedSearchResponse,
)

from ..models import (
    CitationEvent,
    FinalAnswerEvent,
    GenerationConfig,
    Message,
    MessageEvent,
    SearchMode,
    SearchResultsEvent,
    SearchSettings,
    ThinkingEvent,
    ToolCallEvent,
    ToolResultEvent,
    UnknownEvent,
)
from ..sync_methods.retrieval import parse_retrieval_event


class RetrievalSDK:
    """Async SDK for retrieval operations in the v3 API."""

    def __init__(self, client):
        self.client = client

    async def search(
        self,
        query: str,
        search_mode: Optional[str | SearchMode] = SearchMode.custom,
        search_settings: Optional[dict | SearchSettings] = None,
    ) -> WrappedSearchResponse:
        """Conduct a vector and/or graph search (async).

        Args:
            query (str): The search query.
            search_mode (Optional[str | SearchMode]): Search mode ('basic',
                'advanced', or 'custom'). Defaults to 'custom'.
            search_settings (Optional[dict | SearchSettings]): Search settings
                (filters, limits, hybrid options, etc.).

        Returns:
            WrappedSearchResponse: The search results.
        """
        if search_settings and not isinstance(search_settings, dict):
            search_settings = search_settings.model_dump()
        data: dict[str, Any] = {
            "query": query,
            "search_settings": search_settings,
        }
        if search_mode:
            data["search_mode"] = search_mode
        response_dict = await self.client._make_request(
            "POST",
            "retrieval/search",
            json=data,
            version="v3",
        )
        return WrappedSearchResponse(**response_dict)
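
    # Hypothetical usage sketch (illustrative, not part of this module):
    # assumes an initialized async client that exposes this SDK, e.g. as
    # `client.retrieval`.
    #
    #   response = await client.retrieval.search(
    #       query="What is retrieval augmented generation?",
    #       search_mode="basic",
    #       search_settings={"limit": 5},
    #   )
    #   print(response.results)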

    async def completion(
        self,
        messages: list[dict | Message],
        generation_config: Optional[dict | GenerationConfig] = None,
    ) -> WrappedLLMChatCompletion:
        """Get a completion from the model (async).

        Args:
            messages (list[dict | Message]): Messages to generate a completion
                for. Each message must have a 'role' and 'content'.
            generation_config (Optional[dict | GenerationConfig]):
                Configuration for text generation.

        Returns:
            WrappedLLMChatCompletion: The generated completion.
        """
        cast_messages: list[Message] = [
            Message(**msg) if isinstance(msg, dict) else msg
            for msg in messages
        ]
        if generation_config and not isinstance(generation_config, dict):
            generation_config = generation_config.model_dump()
        data: dict[str, Any] = {
            "messages": [msg.model_dump() for msg in cast_messages],
            "generation_config": generation_config,
        }
        response_dict = await self.client._make_request(
            "POST",
            "retrieval/completion",
            json=data,
            version="v3",
        )
        return WrappedLLMChatCompletion(**response_dict)
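
    # Hypothetical usage sketch, assuming the same async client as above:
    #
    #   completion = await client.retrieval.completion(
    #       messages=[
    #           {"role": "system", "content": "You are a helpful assistant."},
    #           {"role": "user", "content": "Write a haiku about search."},
    #       ],
    #       generation_config={"temperature": 0.2},
    #   )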

    async def embedding(self, text: str) -> WrappedEmbeddingResponse:
        """Generate an embedding for the given text (async).

        Args:
            text (str): Text to generate an embedding for.

        Returns:
            WrappedEmbeddingResponse: The embedding.
        """
        data: dict[str, Any] = {
            "text": text,
        }
        response_dict = await self.client._make_request(
            "POST",
            "retrieval/embedding",
            data=data,
            version="v3",
        )
        return WrappedEmbeddingResponse(**response_dict)
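
    # Hypothetical usage sketch, assuming the same async client:
    #
    #   embedding = await client.retrieval.embedding("text to embed")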

    async def rag(
        self,
        query: str,
        rag_generation_config: Optional[dict | GenerationConfig] = None,
        search_mode: Optional[str | SearchMode] = SearchMode.custom,
        search_settings: Optional[dict | SearchSettings] = None,
        task_prompt: Optional[str] = None,
        include_title_if_available: Optional[bool] = False,
        include_web_search: Optional[bool] = False,
    ) -> (
        WrappedRAGResponse
        | AsyncGenerator[
            ThinkingEvent
            | SearchResultsEvent
            | MessageEvent
            | CitationEvent
            | FinalAnswerEvent
            | ToolCallEvent
            | ToolResultEvent
            | UnknownEvent
            | None,
            None,
        ]
    ):
        """Conduct a Retrieval Augmented Generation (RAG) search with the
        given query (async).

        Args:
            query (str): The query to search for.
            rag_generation_config (Optional[dict | GenerationConfig]): RAG
                generation configuration.
            search_mode (Optional[str | SearchMode]): Search mode ('basic',
                'advanced', or 'custom'). Defaults to 'custom'.
            search_settings (Optional[dict | SearchSettings]): Vector search
                settings.
            task_prompt (Optional[str]): Task prompt override.
            include_title_if_available (Optional[bool]): Include the document
                title if available.
            include_web_search (Optional[bool]): Include web search results in
                the RAG context.

        Returns:
            WrappedRAGResponse when not streaming, or an AsyncGenerator of
            parsed retrieval events when `stream=True` is set in the
            generation config.
        """
        if rag_generation_config and not isinstance(
            rag_generation_config, dict
        ):
            rag_generation_config = rag_generation_config.model_dump()
        if search_settings and not isinstance(search_settings, dict):
            search_settings = search_settings.model_dump()
        data: dict[str, Any] = {
            "query": query,
            "rag_generation_config": rag_generation_config,
            "search_settings": search_settings,
            "task_prompt": task_prompt,
            "include_title_if_available": include_title_if_available,
            "include_web_search": include_web_search,
        }
        if search_mode:
            data["search_mode"] = search_mode

        if rag_generation_config and rag_generation_config.get(  # type: ignore
            "stream", False
        ):
            # Streaming: return an async generator that parses each raw
            # server-sent event into a typed retrieval event.
            async def generate_events():
                raw_stream = await self.client._make_streaming_request(
                    "POST",
                    "retrieval/rag",
                    json=data,
                    version="v3",
                )
                async for response in raw_stream:
                    yield parse_retrieval_event(response)

            return generate_events()

        response_dict = await self.client._make_request(
            "POST",
            "retrieval/rag",
            json=data,
            version="v3",
        )
        return WrappedRAGResponse(**response_dict)
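
    # Hypothetical usage sketch for both return shapes, assuming the same
    # async client. With `stream=True` in the generation config, the call
    # returns an async generator of parsed events; otherwise it returns a
    # WrappedRAGResponse.
    #
    #   answer = await client.retrieval.rag(query="What did the report say?")
    #
    #   stream = await client.retrieval.rag(
    #       query="What did the report say?",
    #       rag_generation_config={"stream": True},
    #   )
    #   async for event in stream:
    #       if isinstance(event, FinalAnswerEvent):
    #           print(event)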

    async def agent(
        self,
        message: Optional[dict | Message] = None,
        rag_generation_config: Optional[dict | GenerationConfig] = None,
        research_generation_config: Optional[dict | GenerationConfig] = None,
        search_mode: Optional[str | SearchMode] = SearchMode.custom,
        search_settings: Optional[dict | SearchSettings] = None,
        task_prompt: Optional[str] = None,
        include_title_if_available: Optional[bool] = True,
        conversation_id: Optional[str | UUID] = None,
        max_tool_context_length: Optional[int] = None,
        use_system_context: Optional[bool] = True,
        rag_tools: Optional[list[str]] = None,
        research_tools: Optional[list[str]] = None,
        tools: Optional[list[str]] = None,
        mode: Optional[str] = "rag",
        needs_initial_conversation_name: Optional[bool] = None,
    ) -> (
        WrappedAgentResponse
        | AsyncGenerator[
            ThinkingEvent
            | SearchResultsEvent
            | MessageEvent
            | CitationEvent
            | FinalAnswerEvent
            | ToolCallEvent
            | ToolResultEvent
            | UnknownEvent
            | None,
            None,
        ]
    ):
        """Perform a single turn in a conversation with a RAG agent (async).

        Returns a `WrappedAgentResponse`, or a streaming generator when
        `stream=True` is set in the active generation config.

        Args:
            message (Optional[dict | Message]): Current message to process.
            rag_generation_config (Optional[dict | GenerationConfig]):
                Configuration for generation in 'rag' mode.
            research_generation_config (Optional[dict | GenerationConfig]):
                Configuration for generation in 'research' mode.
            search_mode (Optional[str | SearchMode]): Pre-configured search
                mode: 'basic', 'advanced', or 'custom'.
            search_settings (Optional[dict | SearchSettings]): The search
                configuration object.
            task_prompt (Optional[str]): Optional custom prompt to override
                the default.
            include_title_if_available (Optional[bool]): Include document
                titles from search results.
            conversation_id (Optional[str | UUID]): ID of the conversation.
            max_tool_context_length (Optional[int]): Maximum length of
                returned tool context.
            use_system_context (Optional[bool]): Use the extended prompt for
                generation.
            rag_tools (Optional[list[str]]): Tools to enable for RAG mode.
            research_tools (Optional[list[str]]): Tools to enable for Research
                mode.
            tools (Optional[list[str]]): Tools to execute (deprecated; use
                rag_tools or research_tools instead).
            mode (Optional[str]): Mode to use for generation: 'rag' or
                'research'.
            needs_initial_conversation_name (Optional[bool]): Whether the
                conversation still needs an auto-generated name.

        Returns:
            Either a WrappedAgentResponse or an AsyncGenerator for streaming.
        """
        if rag_generation_config and not isinstance(
            rag_generation_config, dict
        ):
            rag_generation_config = rag_generation_config.model_dump()
        if research_generation_config and not isinstance(
            research_generation_config, dict
        ):
            research_generation_config = (
                research_generation_config.model_dump()
            )
        if search_settings and not isinstance(search_settings, dict):
            search_settings = search_settings.model_dump()

        data: dict[str, Any] = {
            "rag_generation_config": rag_generation_config or {},
            "search_settings": search_settings,
            "task_prompt": task_prompt,
            "include_title_if_available": include_title_if_available,
            "conversation_id": (
                str(conversation_id) if conversation_id else None
            ),
            "max_tool_context_length": max_tool_context_length,
            "use_system_context": use_system_context,
            "mode": mode,
        }

        # Only send the research config when the agent runs in research mode.
        if research_generation_config and mode == "research":
            data["research_generation_config"] = research_generation_config

        # Tool configuration: mode-specific lists, plus the deprecated
        # `tools` parameter for backward compatibility.
        if rag_tools:
            data["rag_tools"] = rag_tools
        if research_tools:
            data["research_tools"] = research_tools
        if tools:
            data["tools"] = tools

        if search_mode:
            data["search_mode"] = search_mode
        if needs_initial_conversation_name:
            data["needs_initial_conversation_name"] = (
                needs_initial_conversation_name
            )

        if message:
            cast_message: Message = (
                Message(**message) if isinstance(message, dict) else message
            )
            data["message"] = cast_message.model_dump()

        # Decide whether to stream based on the generation config that is
        # active for the selected mode.
        is_stream = False
        if mode != "research":
            if isinstance(rag_generation_config, dict):
                is_stream = rag_generation_config.get("stream", False)
            elif rag_generation_config is not None:
                is_stream = rag_generation_config.stream
        else:
            if research_generation_config:
                if isinstance(research_generation_config, dict):
                    is_stream = research_generation_config.get(  # type: ignore
                        "stream", False
                    )
                else:
                    is_stream = research_generation_config.stream

        if is_stream:
            # Streaming: return an async generator that parses each raw
            # server-sent event into a typed retrieval event.
            async def generate_events():
                raw_stream = await self.client._make_streaming_request(
                    "POST",
                    "retrieval/agent",
                    json=data,
                    version="v3",
                )
                async for response in raw_stream:
                    yield parse_retrieval_event(response)

            return generate_events()

        response_dict = await self.client._make_request(
            "POST",
            "retrieval/agent",
            json=data,
            version="v3",
        )
        return WrappedAgentResponse(**response_dict)
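

# Hypothetical end-to-end sketch (illustrative only; assumes an async client
# class such as `R2RAsyncClient` that exposes this SDK as `client.retrieval`
# and a default local endpoint -- neither is defined in this module):
#
#   import asyncio
#
#   async def main():
#       client = R2RAsyncClient("http://localhost:7272")
#       response = await client.retrieval.agent(
#           message={"role": "user", "content": "Summarize my documents."},
#           rag_generation_config={"stream": False},
#           mode="rag",
#       )
#       print(response.results)
#
#   asyncio.run(main())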