123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352 |
- from typing import Any, AsyncGenerator, Optional
- from uuid import UUID
- from shared.api.models import (
- WrappedAgentResponse,
- WrappedEmbeddingResponse,
- WrappedLLMChatCompletion,
- WrappedRAGResponse,
- WrappedSearchResponse,
- )
- from ..models import (
- CitationEvent,
- FinalAnswerEvent,
- GenerationConfig,
- Message,
- MessageEvent,
- SearchMode,
- SearchResultsEvent,
- SearchSettings,
- ThinkingEvent,
- ToolCallEvent,
- ToolResultEvent,
- UnknownEvent,
- )
- from ..sync_methods.retrieval import parse_retrieval_event
class RetrievalSDK:
    """Async SDK for the v3 retrieval API.

    Wraps the ``retrieval/*`` endpoints: search, completion, embedding,
    RAG, and agent. (The previous docstring described this as a document
    SDK, which was a copy-paste error.)
    """

    def __init__(self, client):
        # `client` is stored as-is; methods rely on it exposing
        # `_make_request` and `_make_streaming_request`.
        self.client = client
- async def search(
- self,
- query: str,
- search_mode: Optional[str | SearchMode] = SearchMode.custom,
- search_settings: Optional[dict | SearchSettings] = None,
- ) -> WrappedSearchResponse:
- """
- Conduct a vector and/or graph search (async).
- Args:
- query (str): The search query.
- search_mode (Optional[str | SearchMode]): Search mode ('basic', 'advanced', 'custom'). Defaults to 'custom'.
- search_settings (Optional[dict | SearchSettings]): Search settings (filters, limits, hybrid options, etc.).
- Returns:
- WrappedSearchResponse: The search results.
- """
- if search_settings and not isinstance(search_settings, dict):
- search_settings = search_settings.model_dump()
- data: dict[str, Any] = {
- "query": query,
- "search_settings": search_settings,
- }
- if search_mode:
- data["search_mode"] = search_mode
- response_dict = await self.client._make_request(
- "POST",
- "retrieval/search",
- json=data,
- version="v3",
- )
- return WrappedSearchResponse(**response_dict)
- async def completion(
- self,
- messages: list[dict | Message],
- generation_config: Optional[dict | GenerationConfig] = None,
- ) -> WrappedLLMChatCompletion:
- """
- Get a completion from the model (async).
- Args:
- messages (list[dict | Message]): List of messages to generate completion for. Each message should have a 'role' and 'content'.
- generation_config (Optional[dict | GenerationConfig]): Configuration for text generation.
- Returns:
- WrappedLLMChatCompletion
- """
- cast_messages: list[Message] = [
- Message(**msg) if isinstance(msg, dict) else msg
- for msg in messages
- ]
- if generation_config and not isinstance(generation_config, dict):
- generation_config = generation_config.model_dump()
- data: dict[str, Any] = {
- "messages": [msg.model_dump() for msg in cast_messages],
- "generation_config": generation_config,
- }
- response_dict = await self.client._make_request(
- "POST",
- "retrieval/completion",
- json=data,
- version="v3",
- )
- return WrappedLLMChatCompletion(**response_dict)
- async def embedding(self, text: str) -> WrappedEmbeddingResponse:
- """Generate an embedding for given text.
- Args:
- text (str): Text to generate embeddings for.
- Returns:
- WrappedEmbeddingResponse
- """
- data: dict[str, Any] = {
- "text": text,
- }
- response_dict = await self.client._make_request(
- "POST",
- "retrieval/embedding",
- data=data,
- version="v3",
- )
- return WrappedEmbeddingResponse(**response_dict)
- async def rag(
- self,
- query: str,
- rag_generation_config: Optional[dict | GenerationConfig] = None,
- search_mode: Optional[str | SearchMode] = SearchMode.custom,
- search_settings: Optional[dict | SearchSettings] = None,
- task_prompt: Optional[str] = None,
- include_title_if_available: Optional[bool] = False,
- include_web_search: Optional[bool] = False,
- ) -> (
- WrappedRAGResponse
- | AsyncGenerator[
- ThinkingEvent
- | SearchResultsEvent
- | MessageEvent
- | CitationEvent
- | FinalAnswerEvent
- | ToolCallEvent
- | ToolResultEvent
- | UnknownEvent
- | None,
- None,
- ]
- ):
- """Conducts a Retrieval Augmented Generation (RAG) search with the
- given query.
- Args:
- query (str): The query to search for.
- rag_generation_config (Optional[dict | GenerationConfig]): RAG generation configuration.
- search_settings (Optional[dict | SearchSettings]): Vector search settings.
- task_prompt (Optional[str]): Task prompt override.
- include_title_if_available (Optional[bool]): Include the title if available.
- Returns:
- WrappedRAGResponse | AsyncGenerator[RAGResponse, None]: The RAG response
- """
- if rag_generation_config and not isinstance(
- rag_generation_config, dict
- ):
- rag_generation_config = rag_generation_config.model_dump()
- if search_settings and not isinstance(search_settings, dict):
- search_settings = search_settings.model_dump()
- data: dict[str, Any] = {
- "query": query,
- "rag_generation_config": rag_generation_config,
- "search_settings": search_settings,
- "task_prompt": task_prompt,
- "include_title_if_available": include_title_if_available,
- "include_web_search": include_web_search,
- }
- if search_mode:
- data["search_mode"] = search_mode
- if rag_generation_config and rag_generation_config.get( # type: ignore
- "stream", False
- ):
- async def generate_events():
- raw_stream = await self.client._make_streaming_request(
- "POST",
- "retrieval/rag",
- json=data,
- version="v3",
- )
- async for response in raw_stream:
- yield parse_retrieval_event(response)
- return generate_events()
- response_dict = await self.client._make_request(
- "POST",
- "retrieval/rag",
- json=data,
- version="v3",
- )
- return WrappedRAGResponse(**response_dict)
- async def agent(
- self,
- message: Optional[dict | Message] = None,
- rag_generation_config: Optional[dict | GenerationConfig] = None,
- research_generation_config: Optional[dict | GenerationConfig] = None,
- search_mode: Optional[str | SearchMode] = SearchMode.custom,
- search_settings: Optional[dict | SearchSettings] = None,
- task_prompt: Optional[str] = None,
- include_title_if_available: Optional[bool] = True,
- conversation_id: Optional[str | UUID] = None,
- max_tool_context_length: Optional[int] = None,
- use_system_context: Optional[bool] = True,
- rag_tools: Optional[list[str]] = None,
- research_tools: Optional[list[str]] = None,
- tools: Optional[list[str]] = None,
- mode: Optional[str] = "rag",
- needs_initial_conversation_name: Optional[bool] = None,
- ) -> (
- WrappedAgentResponse
- | AsyncGenerator[
- ThinkingEvent
- | SearchResultsEvent
- | MessageEvent
- | CitationEvent
- | FinalAnswerEvent
- | ToolCallEvent
- | ToolResultEvent
- | UnknownEvent
- | None,
- None,
- ]
- ):
- """
- Performs a single turn in a conversation with a RAG agent (async).
- May return a `WrappedAgentResponse` or a streaming generator if `stream=True`.
- Args:
- message (Optional[dict | Message]): Current message to process.
- messages (Optional[list[dict | Message]]): List of messages (deprecated, use message instead).
- rag_generation_config (Optional[dict | GenerationConfig]): Configuration for RAG generation in 'rag' mode.
- research_generation_config (Optional[dict | GenerationConfig]): Configuration for generation in 'research' mode.
- search_mode (Optional[str | SearchMode]): Pre-configured search modes: "basic", "advanced", or "custom".
- search_settings (Optional[dict | SearchSettings]): The search configuration object.
- task_prompt (Optional[str]): Optional custom prompt to override default.
- include_title_if_available (Optional[bool]): Include document titles from search results.
- conversation_id (Optional[str | UUID]): ID of the conversation.
- tools (Optional[list[str]]): List of tools to execute (deprecated).
- rag_tools (Optional[list[str]]): List of tools to enable for RAG mode.
- research_tools (Optional[list[str]]): List of tools to enable for Research mode.
- max_tool_context_length (Optional[int]): Maximum length of returned tool context.
- use_system_context (Optional[bool]): Use extended prompt for generation.
- mode (Optional[Literal["rag", "research"]]): Mode to use for generation: 'rag' or 'research'.
- Returns:
- Either a WrappedAgentResponse or an AsyncGenerator for streaming.
- """
- if rag_generation_config and not isinstance(
- rag_generation_config, dict
- ):
- rag_generation_config = rag_generation_config.model_dump()
- if research_generation_config and not isinstance(
- research_generation_config, dict
- ):
- research_generation_config = (
- research_generation_config.model_dump()
- )
- if search_settings and not isinstance(search_settings, dict):
- search_settings = search_settings.model_dump()
- data: dict[str, Any] = {
- "rag_generation_config": rag_generation_config or {},
- "search_settings": search_settings,
- "task_prompt": task_prompt,
- "include_title_if_available": include_title_if_available,
- "conversation_id": (
- str(conversation_id) if conversation_id else None
- ),
- "max_tool_context_length": max_tool_context_length,
- "use_system_context": use_system_context,
- "mode": mode,
- }
- # Handle generation configs based on mode
- if research_generation_config and mode == "research":
- data["research_generation_config"] = research_generation_config
- # Handle tool configurations
- if rag_tools:
- data["rag_tools"] = rag_tools
- if research_tools:
- data["research_tools"] = research_tools
- if tools: # Backward compatibility
- data["tools"] = tools
- if search_mode:
- data["search_mode"] = search_mode
- if needs_initial_conversation_name:
- data["needs_initial_conversation_name"] = (
- needs_initial_conversation_name
- )
- if message:
- cast_message: Message = (
- Message(**message) if isinstance(message, dict) else message
- )
- data["message"] = cast_message.model_dump()
- is_stream = False
- if mode != "research":
- if isinstance(rag_generation_config, dict):
- is_stream = rag_generation_config.get("stream", False)
- elif rag_generation_config is not None:
- is_stream = rag_generation_config.stream
- else:
- if research_generation_config:
- if isinstance(research_generation_config, dict):
- is_stream = research_generation_config.get( # type: ignore
- "stream", False
- )
- else:
- is_stream = research_generation_config.stream
- if is_stream:
- async def generate_events():
- raw_stream = await self.client._make_streaming_request(
- "POST",
- "retrieval/agent",
- json=data,
- version="v3",
- )
- async for response in raw_stream:
- yield parse_retrieval_event(response)
- return generate_events()
- response_dict = await self.client._make_request(
- "POST",
- "retrieval/agent",
- json=data,
- version="v3",
- )
- return WrappedAgentResponse(**response_dict)
|