retrieval.py 6.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213
  1. from typing import AsyncGenerator, Optional
  2. from ..models import (
  3. CombinedSearchResponse,
  4. GenerationConfig,
  5. GraphSearchSettings,
  6. Message,
  7. RAGResponse,
  8. SearchMode,
  9. SearchSettings,
  10. )
  11. class RetrievalSDK:
  12. """
  13. SDK for interacting with documents in the v3 API.
  14. """
  15. def __init__(self, client):
  16. self.client = client
  17. async def search(
  18. self,
  19. query: str,
  20. search_mode: Optional[str | SearchMode] = "custom",
  21. search_settings: Optional[dict | SearchSettings] = None,
  22. ) -> CombinedSearchResponse:
  23. """
  24. Conduct a vector and/or KG search.
  25. Args:
  26. query (str): The query to search for.
  27. search_settings (Optional[dict, SearchSettings]]): Vector search settings.
  28. Returns:
  29. CombinedSearchResponse: The search response.
  30. """
  31. if search_mode and not isinstance(search_mode, str):
  32. search_mode = search_mode.value
  33. if search_settings and not isinstance(search_settings, dict):
  34. search_settings = search_settings.model_dump()
  35. data = {
  36. "query": query,
  37. "search_settings": search_settings,
  38. }
  39. if search_mode:
  40. data["search_mode"] = search_mode
  41. return await self.client._make_request(
  42. "POST",
  43. "retrieval/search",
  44. json=data,
  45. version="v3",
  46. )
  47. async def completion(
  48. self,
  49. messages: list[dict | Message],
  50. generation_config: Optional[dict | GenerationConfig] = None,
  51. ):
  52. cast_messages: list[Message] = [
  53. Message(**msg) if isinstance(msg, dict) else msg
  54. for msg in messages
  55. ]
  56. if generation_config and not isinstance(generation_config, dict):
  57. generation_config = generation_config.model_dump()
  58. data = {
  59. "messages": [msg.model_dump() for msg in cast_messages],
  60. "generation_config": generation_config,
  61. }
  62. return await self.client._make_request(
  63. "POST",
  64. "retrieval/completion",
  65. json=data,
  66. version="v3",
  67. )
  68. async def embedding(
  69. self,
  70. text: str,
  71. ):
  72. data = {
  73. "text": text,
  74. }
  75. return await self.client._make_request(
  76. "POST",
  77. "retrieval/embedding",
  78. data=data,
  79. version="v3",
  80. )
  81. async def rag(
  82. self,
  83. query: str,
  84. rag_generation_config: Optional[dict | GenerationConfig] = None,
  85. search_mode: Optional[str | SearchMode] = "custom",
  86. search_settings: Optional[dict | SearchSettings] = None,
  87. task_prompt_override: Optional[str] = None,
  88. include_title_if_available: Optional[bool] = False,
  89. ) -> RAGResponse | AsyncGenerator[RAGResponse, None]:
  90. """
  91. Conducts a Retrieval Augmented Generation (RAG) search with the given query.
  92. Args:
  93. query (str): The query to search for.
  94. rag_generation_config (Optional[dict | GenerationConfig]): RAG generation configuration.
  95. search_settings (Optional[dict | SearchSettings]): Vector search settings.
  96. task_prompt_override (Optional[str]): Task prompt override.
  97. include_title_if_available (Optional[bool]): Include the title if available.
  98. Returns:
  99. RAGResponse | AsyncGenerator[RAGResponse, None]: The RAG response
  100. """
  101. if rag_generation_config and not isinstance(
  102. rag_generation_config, dict
  103. ):
  104. rag_generation_config = rag_generation_config.model_dump()
  105. if search_settings and not isinstance(search_settings, dict):
  106. search_settings = search_settings.model_dump()
  107. data = {
  108. "query": query,
  109. "rag_generation_config": rag_generation_config,
  110. "search_settings": search_settings,
  111. "task_prompt_override": task_prompt_override,
  112. "include_title_if_available": include_title_if_available,
  113. }
  114. if search_mode:
  115. data["search_mode"] = search_mode
  116. if rag_generation_config and rag_generation_config.get( # type: ignore
  117. "stream", False
  118. ):
  119. return self.client._make_streaming_request(
  120. "POST",
  121. "retrieval/rag",
  122. json=data,
  123. version="v3",
  124. )
  125. else:
  126. return await self.client._make_request(
  127. "POST",
  128. "retrieval/rag",
  129. json=data,
  130. version="v3",
  131. )
  132. async def agent(
  133. self,
  134. message: Optional[dict | Message] = None,
  135. rag_generation_config: Optional[dict | GenerationConfig] = None,
  136. search_mode: Optional[str | SearchMode] = "custom",
  137. search_settings: Optional[dict | SearchSettings] = None,
  138. task_prompt_override: Optional[str] = None,
  139. include_title_if_available: Optional[bool] = False,
  140. conversation_id: Optional[str] = None,
  141. ) -> list[Message] | AsyncGenerator[Message, None]:
  142. """
  143. Performs a single turn in a conversation with a RAG agent.
  144. Args:
  145. message (Optional[dict | Message]): The message to send to the agent.
  146. search_settings (Optional[dict | SearchSettings]): Vector search settings.
  147. task_prompt_override (Optional[str]): Task prompt override.
  148. include_title_if_available (Optional[bool]): Include the title if available.
  149. Returns:
  150. List[Message], AsyncGenerator[Message, None]]: The agent response.
  151. """
  152. if rag_generation_config and not isinstance(
  153. rag_generation_config, dict
  154. ):
  155. rag_generation_config = rag_generation_config.model_dump()
  156. if search_settings and not isinstance(search_settings, dict):
  157. search_settings = search_settings.model_dump()
  158. data = {
  159. "rag_generation_config": rag_generation_config or {},
  160. "search_settings": search_settings,
  161. "task_prompt_override": task_prompt_override,
  162. "include_title_if_available": include_title_if_available,
  163. "conversation_id": conversation_id,
  164. }
  165. if search_mode:
  166. data["search_mode"] = search_mode
  167. if message:
  168. cast_message: Message = (
  169. Message(**message) if isinstance(message, dict) else message
  170. )
  171. data["message"] = cast_message.model_dump()
  172. if rag_generation_config and rag_generation_config.get( # type: ignore
  173. "stream", False
  174. ):
  175. return self.client._make_streaming_request(
  176. "POST",
  177. "retrieval/agent",
  178. json=data,
  179. version="v3",
  180. )
  181. else:
  182. return await self.client._make_request(
  183. "POST",
  184. "retrieval/agent",
  185. json=data,
  186. version="v3",
  187. )