sync_retrieval.py

from __future__ import annotations  # enables `X | Y` union syntax on Python < 3.10

import logging
from typing import Generator, Optional

from typing_extensions import deprecated

from ..models import GenerationConfig, Message, RAGResponse, SearchSettings

logger = logging.getLogger(__name__)


class SyncRetrievalMixins:
    def search_documents(
        self,
        query: str,
        search_settings: Optional[dict | SearchSettings] = None,
    ):
        """
        Conduct a vector and/or KG search.

        Args:
            query (str): The query to search for.
            search_settings (Optional[dict | SearchSettings]): Vector search settings.

        Returns:
            SearchResponse: The search response.
        """
        if search_settings and not isinstance(search_settings, dict):
            search_settings = search_settings.model_dump()

        data = {
            "query": query,
            "search_settings": search_settings,
        }
        return self._make_request("POST", "search_documents", json=data)  # type: ignore
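
    # Usage sketch (illustrative; assumes an instantiated client that mixes in
    # SyncRetrievalMixins and implements `_make_request` -- `client` and the
    # settings keys below are assumptions, not part of this module):
    #
    #     results = client.search_documents(
    #         "What is retrieval augmented generation?",
    #         search_settings={"limit": 5},
    #     )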

    @deprecated("Use client.retrieval.search() instead")
    def search(
        self,
        query: str,
        search_settings: Optional[dict | SearchSettings] = None,
    ):
        """
        Conduct a vector and/or KG search.

        Args:
            query (str): The query to search for.
            search_settings (Optional[dict | SearchSettings]): Vector search settings.

        Returns:
            CombinedSearchResponse: The search response.
        """
        if search_settings and not isinstance(search_settings, dict):
            search_settings = search_settings.model_dump()

        data = {
            "query": query,
            "search_settings": search_settings,
        }
        return self._make_request("POST", "search", json=data)  # type: ignore

    @deprecated("Use client.retrieval.completion() instead")
    def completion(
        self,
        messages: list[dict | Message],
        generation_config: Optional[dict | GenerationConfig] = None,
    ):
        """
        Request a chat completion for the given messages.

        Args:
            messages (list[dict | Message]): The messages to complete.
            generation_config (Optional[dict | GenerationConfig]): Generation configuration.

        Returns:
            The completion response.
        """
        cast_messages: list[Message] = [
            Message(**msg) if isinstance(msg, dict) else msg
            for msg in messages
        ]

        if generation_config and not isinstance(generation_config, dict):
            generation_config = generation_config.model_dump()

        data = {
            "messages": [msg.model_dump() for msg in cast_messages],
            "generation_config": generation_config,
        }
        return self._make_request("POST", "completion", json=data)  # type: ignore
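
    # Usage sketch (illustrative; `client` is an assumed instance exposing this
    # mixin). Messages may be passed as plain dicts or `Message` models -- both
    # are normalized before the request. Note this method is deprecated in
    # favor of `client.retrieval.completion()`:
    #
    #     response = client.completion(
    #         messages=[
    #             {"role": "system", "content": "You are a helpful assistant."},
    #             Message(role="user", content="Hello!"),
    #         ]
    #     )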

    @deprecated("Use client.retrieval.rag() instead")
    def rag(
        self,
        query: str,
        rag_generation_config: Optional[dict | GenerationConfig] = None,
        search_settings: Optional[dict | SearchSettings] = None,
        task_prompt_override: Optional[str] = None,
        include_title_if_available: Optional[bool] = False,
    ) -> RAGResponse | Generator[RAGResponse, None, None]:
        """
        Conducts a Retrieval Augmented Generation (RAG) search with the given query.

        Args:
            query (str): The query to search for.
            rag_generation_config (Optional[dict | GenerationConfig]): RAG generation configuration.
            search_settings (Optional[dict | SearchSettings]): Vector search settings.
            task_prompt_override (Optional[str]): Task prompt override.
            include_title_if_available (Optional[bool]): Include the title if available.

        Returns:
            RAGResponse | Generator[RAGResponse, None, None]: The RAG response,
            or a generator of partial responses when streaming is enabled.
        """
        if rag_generation_config and not isinstance(
            rag_generation_config, dict
        ):
            rag_generation_config = rag_generation_config.model_dump()
        if search_settings and not isinstance(search_settings, dict):
            search_settings = search_settings.model_dump()

        data = {
            "query": query,
            "rag_generation_config": rag_generation_config,
            "search_settings": search_settings,
            "task_prompt_override": task_prompt_override,
            "include_title_if_available": include_title_if_available,
        }

        if rag_generation_config and rag_generation_config.get(  # type: ignore
            "stream", False
        ):
            return self._make_streaming_request("POST", "rag", json=data)  # type: ignore
        else:
            return self._make_request("POST", "rag", json=data)  # type: ignore
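
    # Usage sketch (illustrative; `client` is an assumed instance exposing this
    # mixin). Setting `stream=True` in the generation config switches the
    # return value from a single response to a generator of chunks:
    #
    #     response = client.rag("Summarize the uploaded report.")
    #
    #     for chunk in client.rag(
    #         "Summarize the uploaded report.",
    #         rag_generation_config={"stream": True},
    #     ):
    #         print(chunk)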

    @deprecated("Use client.retrieval.agent() instead")
    def agent(
        self,
        message: Optional[dict | Message] = None,
        rag_generation_config: Optional[dict | GenerationConfig] = None,
        search_settings: Optional[dict | SearchSettings] = None,
        task_prompt_override: Optional[str] = None,
        include_title_if_available: Optional[bool] = False,
        conversation_id: Optional[str] = None,
        # TODO - Deprecate messages
        messages: Optional[list[dict | Message]] = None,
    ) -> list[Message] | Generator[Message, None, None]:
        """
        Performs a single turn in a conversation with a RAG agent.

        Args:
            message (Optional[dict | Message]): The message to send to the agent.
            rag_generation_config (Optional[dict | GenerationConfig]): RAG generation configuration.
            search_settings (Optional[dict | SearchSettings]): Vector search settings.
            task_prompt_override (Optional[str]): Task prompt override.
            include_title_if_available (Optional[bool]): Include the title if available.
            conversation_id (Optional[str]): The ID of the conversation to continue.
            messages (Optional[list[dict | Message]]): Deprecated; use `message` instead.

        Returns:
            list[Message] | Generator[Message, None, None]: The agent response,
            or a generator of messages when streaming is enabled.
        """
        if messages:
            logger.warning(
                "The `messages` argument is deprecated. Please use `message` instead."
            )
        if rag_generation_config and not isinstance(
            rag_generation_config, dict
        ):
            rag_generation_config = rag_generation_config.model_dump()
        if search_settings and not isinstance(search_settings, dict):
            search_settings = search_settings.model_dump()

        data = {
            "rag_generation_config": rag_generation_config or {},
            "search_settings": search_settings or {},
            "task_prompt_override": task_prompt_override,
            "include_title_if_available": include_title_if_available,
            "conversation_id": conversation_id,
        }

        if message:
            cast_message: Message = (
                Message(**message) if isinstance(message, dict) else message
            )
            data["message"] = cast_message.model_dump()

        if messages:
            data["messages"] = [
                (
                    Message(**msg).model_dump()  # type: ignore
                    if isinstance(msg, dict)
                    else msg.model_dump()  # type: ignore
                )
                for msg in messages
            ]

        if rag_generation_config and rag_generation_config.get(  # type: ignore
            "stream", False
        ):
            return self._make_streaming_request("POST", "agent", json=data)  # type: ignore
        else:
            return self._make_request("POST", "agent", json=data)  # type: ignore
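
    # Usage sketch (illustrative; `client` is an assumed instance exposing this
    # mixin). A single conversational turn passes one `message`; the server
    # tracks history via `conversation_id`:
    #
    #     reply = client.agent(
    #         message={"role": "user", "content": "What does the report conclude?"},
    #         conversation_id=None,  # or an existing conversation's ID to continue it
    #     )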

    def embedding(
        self,
        content: str,
    ) -> list[float]:
        """
        Generate embeddings for the provided content.

        Args:
            content (str): The text content to embed.

        Returns:
            list[float]: The generated embedding vector.
        """
        return self._make_request("POST", "embedding", json=content)  # type: ignore
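
    # Usage sketch (illustrative; `client` is an assumed instance exposing this
    # mixin):
    #
    #     vector = client.embedding("retrieval augmented generation")
    #     print(len(vector))  # dimensionality of the embedding model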