retrieval.py

import json
from typing import Any, Generator, Optional
from uuid import UUID

from shared.api.models import (
    WrappedAgentResponse,
    WrappedEmbeddingResponse,
    WrappedLLMChatCompletion,
    WrappedRAGResponse,
    WrappedSearchResponse,
)

from ..models import (
    AgentEvent,
    CitationData,
    CitationEvent,
    Delta,
    DeltaPayload,
    FinalAnswerData,
    FinalAnswerEvent,
    GenerationConfig,
    Message,
    MessageData,
    MessageDelta,
    MessageEvent,
    SearchMode,
    SearchResultsData,
    SearchResultsEvent,
    SearchSettings,
    ThinkingData,
    ThinkingEvent,
    ToolCallData,
    ToolCallEvent,
    ToolResultData,
    ToolResultEvent,
    UnknownEvent,
)


def _coerce_delta(data_obj: dict) -> dict:
    """Convert a raw ``delta`` dict (and its nested content/payload items)
    into typed Delta/MessageDelta/DeltaPayload models, mutating ``data_obj``
    in place and returning it."""
    if "delta" in data_obj and isinstance(data_obj["delta"], dict):
        delta_dict = data_obj["delta"]
        # Convert content items to MessageDelta objects.
        if "content" in delta_dict and isinstance(delta_dict["content"], list):
            parsed_content = []
            for item in delta_dict["content"]:
                if isinstance(item, dict):
                    # Parse the nested payload into a DeltaPayload first.
                    if "payload" in item and isinstance(item["payload"], dict):
                        item["payload"] = DeltaPayload(**item["payload"])
                    parsed_content.append(MessageDelta(**item))
            # Replace with parsed content.
            delta_dict["content"] = parsed_content
        # Create a properly typed Delta object.
        data_obj["delta"] = Delta(**delta_dict)
    return data_obj


def parse_retrieval_event(raw: dict) -> Optional[AgentEvent]:
    """
    Convert a raw SSE event dict into a typed Pydantic model.

    Example raw dict:
        {
            "event": "message",
            "data": "{\"id\": \"msg_partial\", \"object\": \"agent.message.delta\", \"delta\": {...}}"
        }
    """
    event_type = raw.get("event", "unknown")
    # A "done" event signals that the SSE stream is finished; return None.
    if event_type == "done":
        return None

    # The SSE "data" field is JSON-encoded, so parse it first.
    data_str = raw.get("data", "")
    try:
        data_obj = json.loads(data_str)
    except json.JSONDecodeError as e:
        # You can decide whether to raise or return UnknownEvent here.
        raise ValueError(f"Could not parse JSON in SSE event data: {e}") from e

    # Branch on event_type to build the right Pydantic model.
    if event_type == "search_results":
        return SearchResultsEvent(
            event=event_type,
            data=SearchResultsData(**data_obj),
        )
    elif event_type == "message":
        # Coerce the nested delta structure before creating MessageData.
        return MessageEvent(
            event=event_type,
            data=MessageData(**_coerce_delta(data_obj)),
        )
    elif event_type == "citation":
        return CitationEvent(event=event_type, data=CitationData(**data_obj))
    elif event_type == "tool_call":
        return ToolCallEvent(event=event_type, data=ToolCallData(**data_obj))
    elif event_type == "tool_result":
        return ToolResultEvent(event=event_type, data=ToolResultData(**data_obj))
    elif event_type == "thinking":
        # Coerce the nested delta structure before creating ThinkingData.
        return ThinkingEvent(
            event=event_type,
            data=ThinkingData(**_coerce_delta(data_obj)),
        )
    elif event_type == "final_answer":
        return FinalAnswerEvent(event=event_type, data=FinalAnswerData(**data_obj))
    else:
        # Fallback if it doesn't match any known event type.
        return UnknownEvent(
            event=event_type,
            data=data_obj,
        )
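
# A minimal usage sketch for the parser above. The event name and payload
# fields here are illustrative assumptions, not a documented wire format:
#
#     raw = {"event": "citation", "data": '{"id": "cit_1"}'}
#     evt = parse_retrieval_event(raw)
#     if evt is not None:
#         print(type(evt).__name__)  # -> "CitationEvent"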

class RetrievalSDK:
    """SDK for retrieval operations (search, completion, embedding, RAG,
    and agent) in the v3 API."""

    def __init__(self, client):
        self.client = client

    def search(
        self,
        query: str,
        search_mode: Optional[str | SearchMode] = SearchMode.custom,
        search_settings: Optional[dict | SearchSettings] = None,
    ) -> WrappedSearchResponse:
        """Conduct a vector and/or graph search.

        Args:
            query (str): The search query.
            search_mode (Optional[str | SearchMode]): Search mode ('basic', 'advanced', 'custom'). Defaults to 'custom'.
            search_settings (Optional[dict | SearchSettings]): Search settings (filters, limits, hybrid options, etc.).

        Returns:
            WrappedSearchResponse
        """
        if search_settings and not isinstance(search_settings, dict):
            search_settings = search_settings.model_dump()
        data: dict[str, Any] = {
            "query": query,
            "search_settings": search_settings,
        }
        if search_mode:
            data["search_mode"] = search_mode
        response_dict = self.client._make_request(
            "POST",
            "retrieval/search",
            json=data,
            version="v3",
        )
        return WrappedSearchResponse(**response_dict)
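
    # A minimal usage sketch, assuming an initialized client whose
    # `retrieval` attribute is this SDK (the attribute name is an assumption):
    #
    #     results = client.retrieval.search(
    #         query="What is retrieval-augmented generation?",
    #         search_mode="basic",
    #         search_settings={"limit": 5},
    #     )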

    def completion(
        self,
        messages: list[dict | Message],
        generation_config: Optional[dict | GenerationConfig] = None,
    ) -> WrappedLLMChatCompletion:
        """
        Get a completion from the model.

        Args:
            messages (list[dict | Message]): List of messages to generate a completion for. Each message should have a 'role' and 'content'.
            generation_config (Optional[dict | GenerationConfig]): Configuration for text generation.

        Returns:
            WrappedLLMChatCompletion
        """
        cast_messages: list[Message] = [
            Message(**msg) if isinstance(msg, dict) else msg
            for msg in messages
        ]
        if generation_config and not isinstance(generation_config, dict):
            generation_config = generation_config.model_dump()
        data: dict[str, Any] = {
            "messages": [msg.model_dump() for msg in cast_messages],
            "generation_config": generation_config,
        }
        response_dict = self.client._make_request(
            "POST",
            "retrieval/completion",
            json=data,
            version="v3",
        )
        return WrappedLLMChatCompletion(**response_dict)
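
    # A minimal usage sketch (message content and config values are
    # illustrative):
    #
    #     completion = client.retrieval.completion(
    #         messages=[
    #             {"role": "system", "content": "You are a helpful assistant."},
    #             {"role": "user", "content": "Summarize vector search in one line."},
    #         ],
    #         generation_config={"temperature": 0.2},
    #     )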

    def embedding(self, text: str) -> WrappedEmbeddingResponse:
        """Generate an embedding for the given text.

        Args:
            text (str): Text to generate embeddings for.

        Returns:
            WrappedEmbeddingResponse
        """
        data: dict[str, Any] = {
            "text": text,
        }
        response_dict = self.client._make_request(
            "POST",
            "retrieval/embedding",
            data=data,
            version="v3",
        )
        return WrappedEmbeddingResponse(**response_dict)
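
    # A minimal usage sketch (same client assumption as above):
    #
    #     emb = client.retrieval.embedding(text="vector databases")
    #     # emb is a WrappedEmbeddingResponse wrapping the raw vector.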

    def rag(
        self,
        query: str,
        rag_generation_config: Optional[dict | GenerationConfig] = None,
        search_mode: Optional[str | SearchMode] = SearchMode.custom,
        search_settings: Optional[dict | SearchSettings] = None,
        task_prompt: Optional[str] = None,
        include_title_if_available: Optional[bool] = False,
        include_web_search: Optional[bool] = False,
    ) -> (
        WrappedRAGResponse
        | Generator[
            ThinkingEvent
            | SearchResultsEvent
            | MessageEvent
            | CitationEvent
            | FinalAnswerEvent
            | ToolCallEvent
            | ToolResultEvent
            | UnknownEvent
            | None,
            None,
            None,
        ]
    ):
        """Conduct a Retrieval Augmented Generation (RAG) search with the
        given query.

        Args:
            query (str): The query to search for.
            rag_generation_config (Optional[dict | GenerationConfig]): RAG generation configuration.
            search_mode (Optional[str | SearchMode]): Search mode ('basic', 'advanced', 'custom'). Defaults to 'custom'.
            search_settings (Optional[dict | SearchSettings]): Vector search settings.
            task_prompt (Optional[str]): Task prompt override.
            include_title_if_available (Optional[bool]): Include the title if available.
            include_web_search (Optional[bool]): Include web search results in the context.

        Returns:
            WrappedRAGResponse | Generator[AgentEvent | None, None, None]:
                The RAG response, or a generator of typed events when
                streaming is enabled.
        """
        if rag_generation_config and not isinstance(
            rag_generation_config, dict
        ):
            rag_generation_config = rag_generation_config.model_dump()
        if search_settings and not isinstance(search_settings, dict):
            search_settings = search_settings.model_dump()
        data: dict[str, Any] = {
            "query": query,
            "rag_generation_config": rag_generation_config,
            "search_settings": search_settings,
            "task_prompt": task_prompt,
            "include_title_if_available": include_title_if_available,
            "include_web_search": include_web_search,
        }
        if search_mode:
            data["search_mode"] = search_mode
        if rag_generation_config and rag_generation_config.get(  # type: ignore
            "stream", False
        ):
            raw_stream = self.client._make_streaming_request(
                "POST",
                "retrieval/rag",
                json=data,
                version="v3",
            )
            # Wrap the raw stream so each SSE event is parsed into a typed model.
            return (parse_retrieval_event(event) for event in raw_stream)
        response_dict = self.client._make_request(
            "POST",
            "retrieval/rag",
            json=data,
            version="v3",
        )
        return WrappedRAGResponse(**response_dict)
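
    # A minimal usage sketch for both return shapes. Streaming is driven
    # entirely by `rag_generation_config["stream"]`; the query is illustrative:
    #
    #     # Non-streaming: returns a WrappedRAGResponse.
    #     resp = client.retrieval.rag(query="Who wrote Dune?")
    #
    #     # Streaming: returns a generator of typed events.
    #     for event in client.retrieval.rag(
    #         query="Who wrote Dune?",
    #         rag_generation_config={"stream": True},
    #     ):
    #         if isinstance(event, FinalAnswerEvent):
    #             print(event.data)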

    def agent(
        self,
        message: Optional[dict | Message] = None,
        rag_generation_config: Optional[dict | GenerationConfig] = None,
        research_generation_config: Optional[dict | GenerationConfig] = None,
        search_mode: Optional[str | SearchMode] = SearchMode.custom,
        search_settings: Optional[dict | SearchSettings] = None,
        task_prompt: Optional[str] = None,
        include_title_if_available: Optional[bool] = True,
        conversation_id: Optional[str | UUID] = None,
        max_tool_context_length: Optional[int] = None,
        use_system_context: Optional[bool] = True,
        rag_tools: Optional[list[str]] = None,
        research_tools: Optional[list[str]] = None,
        tools: Optional[list[str]] = None,
        mode: Optional[str] = "rag",
        needs_initial_conversation_name: Optional[bool] = None,
    ) -> (
        WrappedAgentResponse
        | Generator[
            ThinkingEvent
            | SearchResultsEvent
            | MessageEvent
            | CitationEvent
            | FinalAnswerEvent
            | ToolCallEvent
            | ToolResultEvent
            | UnknownEvent
            | None,
            None,
            None,
        ]
    ):
        """Perform a single turn in a conversation with a RAG agent.

        Args:
            message (Optional[dict | Message]): The message to send to the agent.
            rag_generation_config (Optional[dict | GenerationConfig]): Configuration for RAG generation in 'rag' mode.
            research_generation_config (Optional[dict | GenerationConfig]): Configuration for generation in 'research' mode.
            search_mode (Optional[str | SearchMode]): Pre-configured search modes: "basic", "advanced", or "custom".
            search_settings (Optional[dict | SearchSettings]): Vector search settings.
            task_prompt (Optional[str]): Task prompt override.
            include_title_if_available (Optional[bool]): Include the title if available.
            conversation_id (Optional[str | UUID]): ID of the conversation for maintaining context.
            max_tool_context_length (Optional[int]): Maximum context length for tool replies.
            use_system_context (Optional[bool]): Whether to use system context in the prompt.
            rag_tools (Optional[list[str]]): List of tools to enable for RAG mode.
                Available tools: "search_file_knowledge", "content", "web_search", "web_scrape", "search_file_descriptions".
            research_tools (Optional[list[str]]): List of tools to enable for Research mode.
                Available tools: "rag", "reasoning", "critique", "python_executor".
            tools (Optional[list[str]]): Deprecated. List of tools to execute.
            mode (Optional[str]): Mode to use for generation: "rag" for standard retrieval or "research" for deep analysis.
                Defaults to "rag".
            needs_initial_conversation_name (Optional[bool]): Whether the conversation still needs an initial name.

        Returns:
            WrappedAgentResponse | Generator[AgentEvent | None, None, None]:
                The agent response, or a generator of typed events when
                streaming is enabled.
        """
        if rag_generation_config and not isinstance(
            rag_generation_config, dict
        ):
            rag_generation_config = rag_generation_config.model_dump()
        if research_generation_config and not isinstance(
            research_generation_config, dict
        ):
            research_generation_config = (
                research_generation_config.model_dump()
            )
        if search_settings and not isinstance(search_settings, dict):
            search_settings = search_settings.model_dump()
        data: dict[str, Any] = {
            "rag_generation_config": rag_generation_config or {},
            "search_settings": search_settings,
            "task_prompt": task_prompt,
            "include_title_if_available": include_title_if_available,
            "conversation_id": (
                str(conversation_id) if conversation_id else None
            ),
            "max_tool_context_length": max_tool_context_length,
            "use_system_context": use_system_context,
            "mode": mode,
        }
        # Handle generation configs based on mode.
        if research_generation_config and mode == "research":
            data["research_generation_config"] = research_generation_config
        # Handle tool configurations.
        if rag_tools:
            data["rag_tools"] = rag_tools
        if research_tools:
            data["research_tools"] = research_tools
        if tools:  # Backward compatibility.
            data["tools"] = tools
        if search_mode:
            data["search_mode"] = search_mode
        if needs_initial_conversation_name:
            data["needs_initial_conversation_name"] = (
                needs_initial_conversation_name
            )
        if message:
            cast_message: Message = (
                Message(**message) if isinstance(message, dict) else message
            )
            data["message"] = cast_message.model_dump()
        # Determine whether the caller asked for streaming, based on the
        # generation config that applies to the selected mode.
        is_stream = False
        if mode != "research":
            if isinstance(rag_generation_config, dict):
                is_stream = rag_generation_config.get("stream", False)
            elif rag_generation_config is not None:
                is_stream = rag_generation_config.stream
        else:
            if research_generation_config:
                if isinstance(research_generation_config, dict):
                    is_stream = research_generation_config.get(  # type: ignore
                        "stream", False
                    )
                else:
                    is_stream = research_generation_config.stream
        if is_stream:
            raw_stream = self.client._make_streaming_request(
                "POST",
                "retrieval/agent",
                json=data,
                version="v3",
            )
            return (parse_retrieval_event(event) for event in raw_stream)
        response_dict = self.client._make_request(
            "POST",
            "retrieval/agent",
            json=data,
            version="v3",
        )
        return WrappedAgentResponse(**response_dict)
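
# A minimal usage sketch for a single agent turn. The tool names come from
# the docstring above; the message content is illustrative:
#
#     turn = client.retrieval.agent(
#         message={"role": "user", "content": "Compare the two reports."},
#         mode="rag",
#         rag_tools=["search_file_knowledge", "web_search"],
#     )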