retrieval.py

from __future__ import annotations  # allow `X | Y` annotations on Python < 3.10

import logging
from typing import AsyncGenerator, Optional

from typing_extensions import deprecated

from ..models import (
    GenerationConfig,
    GraphSearchSettings,
    Message,
    RAGResponse,
    SearchSettings,
)

logger = logging.getLogger(__name__)

class RetrievalMixins:
    async def search_documents(
        self,
        query: str,
        settings: Optional[dict | SearchSettings] = None,
    ):
        """
        Conduct a vector and/or knowledge-graph (KG) search over documents.

        Args:
            query (str): The query to search for.
            settings (Optional[Union[dict, SearchSettings]]): Search settings.

        Returns:
            SearchResponse: The search response.
        """
        if settings and not isinstance(settings, dict):
            settings = settings.model_dump()
        data = {
            "query": query,
            "settings": settings,
        }
        return await self._make_request("POST", "search_documents", json=data)  # type: ignore
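
    # A minimal usage sketch, assuming a concrete async client mixes this
    # class in and implements `_make_request` (the `client` name and the
    # `limit` setting are illustrative, not part of this module):
    #
    #     results = await client.search_documents(
    #         query="What is retrieval augmented generation?",
    #         settings={"limit": 10},
    #     )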
    @deprecated("Use client.retrieval.search() instead")
    async def search(
        self,
        query: str,
        chunk_search_settings: Optional[dict | SearchSettings] = None,
        graph_search_settings: Optional[dict | GraphSearchSettings] = None,
    ):
        """
        Conduct a vector and/or KG search.

        Args:
            query (str): The query to search for.
            chunk_search_settings (Optional[Union[dict, SearchSettings]]): Vector search settings.
            graph_search_settings (Optional[Union[dict, GraphSearchSettings]]): KG search settings.

        Returns:
            CombinedSearchResponse: The search response.
        """
        if chunk_search_settings and not isinstance(chunk_search_settings, dict):
            chunk_search_settings = chunk_search_settings.model_dump()
        if graph_search_settings and not isinstance(graph_search_settings, dict):
            graph_search_settings = graph_search_settings.model_dump()
        data = {
            "query": query,
            "chunk_search_settings": chunk_search_settings,
            "graph_search_settings": graph_search_settings,
        }
        return await self._make_request("POST", "search", json=data)  # type: ignore
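
    # A hedged sketch of calling the deprecated `search`; settings can be
    # passed as dicts or as the typed models (the field names shown on the
    # settings objects are assumptions for illustration):
    #
    #     response = await client.search(
    #         query="climate policy",
    #         chunk_search_settings={"limit": 5},        # hypothetical field
    #         graph_search_settings={"enabled": True},   # hypothetical field
    #     )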
    @deprecated("Use client.retrieval.completion() instead")
    async def completion(
        self,
        messages: list[dict | Message],
        generation_config: Optional[dict | GenerationConfig] = None,
    ):
        """
        Request a chat completion for the given message history.

        Args:
            messages (list[Union[dict, Message]]): The conversation messages.
            generation_config (Optional[Union[dict, GenerationConfig]]): Generation configuration.
        """
        # Normalize plain dicts into Message models before serializing.
        cast_messages: list[Message] = [
            Message(**msg) if isinstance(msg, dict) else msg for msg in messages
        ]
        if generation_config and not isinstance(generation_config, dict):
            generation_config = generation_config.model_dump()
        data = {
            "messages": [msg.model_dump() for msg in cast_messages],
            "generation_config": generation_config,
        }
        return await self._make_request("POST", "completion", json=data)  # type: ignore
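
    # Example call, assuming Message accepts `role` and `content` keys (as
    # the dict-to-Message expansion above implies); the config key shown is
    # illustrative:
    #
    #     reply = await client.completion(
    #         messages=[
    #             {"role": "system", "content": "You are a helpful assistant."},
    #             {"role": "user", "content": "Summarize the report."},
    #         ],
    #         generation_config={"temperature": 0.2},
    #     )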
    @deprecated("Use client.retrieval.rag() instead")
    async def rag(
        self,
        query: str,
        rag_generation_config: Optional[dict | GenerationConfig] = None,
        chunk_search_settings: Optional[dict | SearchSettings] = None,
        graph_search_settings: Optional[dict | GraphSearchSettings] = None,
        task_prompt_override: Optional[str] = None,
        include_title_if_available: Optional[bool] = False,
    ) -> RAGResponse | AsyncGenerator[RAGResponse, None]:
        """
        Conducts a Retrieval Augmented Generation (RAG) search with the given query.

        Args:
            query (str): The query to search for.
            rag_generation_config (Optional[Union[dict, GenerationConfig]]): RAG generation configuration.
            chunk_search_settings (Optional[Union[dict, SearchSettings]]): Vector search settings.
            graph_search_settings (Optional[Union[dict, GraphSearchSettings]]): KG search settings.
            task_prompt_override (Optional[str]): Task prompt override.
            include_title_if_available (Optional[bool]): Include the title if available.

        Returns:
            Union[RAGResponse, AsyncGenerator[RAGResponse, None]]: The RAG response.
        """
        if rag_generation_config and not isinstance(rag_generation_config, dict):
            rag_generation_config = rag_generation_config.model_dump()
        if chunk_search_settings and not isinstance(chunk_search_settings, dict):
            chunk_search_settings = chunk_search_settings.model_dump()
        if graph_search_settings and not isinstance(graph_search_settings, dict):
            graph_search_settings = graph_search_settings.model_dump()
        data = {
            "query": query,
            "rag_generation_config": rag_generation_config,
            "chunk_search_settings": chunk_search_settings,
            "graph_search_settings": graph_search_settings,
            "task_prompt_override": task_prompt_override,
            "include_title_if_available": include_title_if_available,
        }
        if rag_generation_config and rag_generation_config.get("stream", False):  # type: ignore
            return self._make_streaming_request("POST", "rag", json=data)  # type: ignore
        return await self._make_request("POST", "rag", json=data)  # type: ignore
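
    # Sketch of both return modes of `rag`: with `stream=True` in the
    # generation config the awaited call returns an async generator rather
    # than a single RAGResponse:
    #
    #     answer = await client.rag(query="Who wrote the memo?")
    #
    #     stream = await client.rag(
    #         query="Who wrote the memo?",
    #         rag_generation_config={"stream": True},
    #     )
    #     async for chunk in stream:
    #         print(chunk)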
    @deprecated("Use client.retrieval.agent() instead")
    async def agent(
        self,
        message: Optional[dict | Message] = None,
        rag_generation_config: Optional[dict | GenerationConfig] = None,
        chunk_search_settings: Optional[dict | SearchSettings] = None,
        graph_search_settings: Optional[dict | GraphSearchSettings] = None,
        task_prompt_override: Optional[str] = None,
        include_title_if_available: Optional[bool] = False,
        conversation_id: Optional[str] = None,
        branch_id: Optional[str] = None,
        # TODO - Deprecate messages
        messages: Optional[list[dict | Message]] = None,
    ) -> list[Message] | AsyncGenerator[Message, None]:
        """
        Performs a single turn in a conversation with a RAG agent.

        Args:
            message (Optional[Union[dict, Message]]): The message to send to the agent.
            rag_generation_config (Optional[Union[dict, GenerationConfig]]): RAG generation configuration.
            chunk_search_settings (Optional[Union[dict, SearchSettings]]): Vector search settings.
            graph_search_settings (Optional[Union[dict, GraphSearchSettings]]): KG search settings.
            task_prompt_override (Optional[str]): Task prompt override.
            include_title_if_available (Optional[bool]): Include the title if available.

        Returns:
            Union[list[Message], AsyncGenerator[Message, None]]: The agent response.
        """
        if messages:
            logger.warning(
                "The `messages` argument is deprecated. Please use `message` instead."
            )
        if rag_generation_config and not isinstance(rag_generation_config, dict):
            rag_generation_config = rag_generation_config.model_dump()
        if chunk_search_settings and not isinstance(chunk_search_settings, dict):
            chunk_search_settings = chunk_search_settings.model_dump()
        if graph_search_settings and not isinstance(graph_search_settings, dict):
            graph_search_settings = graph_search_settings.model_dump()
        data = {
            "rag_generation_config": rag_generation_config or {},
            "chunk_search_settings": chunk_search_settings or {},
            "graph_search_settings": graph_search_settings,
            "task_prompt_override": task_prompt_override,
            "include_title_if_available": include_title_if_available,
            "conversation_id": conversation_id,
            "branch_id": branch_id,
        }
        if message:
            cast_message: Message = (
                Message(**message) if isinstance(message, dict) else message
            )
            data["message"] = cast_message.model_dump()
        if messages:
            data["messages"] = [
                Message(**msg).model_dump()  # type: ignore
                if isinstance(msg, dict)
                else msg.model_dump()  # type: ignore
                for msg in messages
            ]
        if rag_generation_config and rag_generation_config.get("stream", False):  # type: ignore
            return self._make_streaming_request("POST", "agent", json=data)  # type: ignore
        return await self._make_request("POST", "agent", json=data)  # type: ignore
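
    # Single-turn agent call, per the docstring above; `conversation_id`
    # lets consecutive calls share conversation state (the id value and the
    # message content are illustrative):
    #
    #     turn = await client.agent(
    #         message={"role": "user", "content": "Find the latest filing."},
    #         rag_generation_config={"stream": False},
    #         conversation_id="example-conversation-id",
    #     )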
    @deprecated("Use client.retrieval.embedding() instead")
    async def embedding(
        self,
        content: str,
    ) -> list[float]:
        """
        Generate embeddings for the provided content.

        Args:
            content (str): The text content to embed.

        Returns:
            list[float]: The generated embedding vector.
        """
        return await self._make_request("POST", "embedding", json=content)  # type: ignore
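

# A minimal end-to-end sketch, assuming a concrete client class (the name
# `R2RAsyncClient` is hypothetical) mixes RetrievalMixins in and implements
# `_make_request`:
#
#     client = R2RAsyncClient("http://localhost:7272")
#     vector = await client.embedding("a sentence to embed")
#     assert isinstance(vector, list)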