123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245 |
from __future__ import annotations  # postponed annotations: enables X | Y union syntax on Python < 3.10
- import logging
- from typing import AsyncGenerator, Optional
- from typing_extensions import deprecated
- from ..models import (
- GenerationConfig,
- GraphSearchSettings,
- Message,
- RAGResponse,
- SearchSettings,
- )
# Module-level logger. Use the module's own logger (not the bare root logger)
# so library log records are attributable and handler/level configuration
# does not leak to or from unrelated code.
logger = logging.getLogger(__name__)
class RetrievalMixins:
    """Mixin providing retrieval endpoints (search, RAG, agent, completion,
    embedding) for the async client.

    All network methods delegate to ``self._make_request`` /
    ``self._make_streaming_request``, which the host client class is expected
    to provide. Most methods are deprecated in favor of the
    ``client.retrieval.*`` namespace (see the ``@deprecated`` decorators).
    """

    async def search_documents(
        self,
        query: str,
        settings: Optional[dict | SearchSettings] = None,
    ):
        """
        Conduct a document-level search.

        Args:
            query (str): The query to search for.
            settings (Optional[Union[dict, SearchSettings]]): Search settings;
                a pydantic model is accepted and serialized via ``model_dump()``.

        Returns:
            SearchResponse: The search response.
        """
        # Accept either a plain dict or a pydantic settings model.
        if settings and not isinstance(settings, dict):
            settings = settings.model_dump()

        data = {
            "query": query,
            "settings": settings,
        }
        return await self._make_request("POST", "search_documents", json=data)  # type: ignore

    @deprecated("Use client.retrieval.search() instead")
    async def search(
        self,
        query: str,
        chunk_search_settings: Optional[dict | SearchSettings] = None,
        graph_search_settings: Optional[dict | GraphSearchSettings] = None,
    ):
        """
        Conduct a vector and/or KG search.

        Args:
            query (str): The query to search for.
            chunk_search_settings (Optional[Union[dict, SearchSettings]]): Vector search settings.
            graph_search_settings (Optional[Union[dict, GraphSearchSettings]]): KG search settings.

        Returns:
            CombinedSearchResponse: The search response.
        """
        # Pydantic models are serialized to plain dicts before the request.
        if chunk_search_settings and not isinstance(
            chunk_search_settings, dict
        ):
            chunk_search_settings = chunk_search_settings.model_dump()

        if graph_search_settings and not isinstance(
            graph_search_settings, dict
        ):
            graph_search_settings = graph_search_settings.model_dump()

        data = {
            "query": query,
            "chunk_search_settings": chunk_search_settings,
            "graph_search_settings": graph_search_settings,
        }
        return await self._make_request("POST", "search", json=data)  # type: ignore

    @deprecated("Use client.retrieval.completion() instead")
    async def completion(
        self,
        messages: list[dict | Message],
        generation_config: Optional[dict | GenerationConfig] = None,
    ):
        """
        Request a chat completion for the given messages.

        Args:
            messages (list[Union[dict, Message]]): Conversation messages;
                dicts are converted to ``Message`` models.
            generation_config (Optional[Union[dict, GenerationConfig]]):
                Generation configuration.

        Returns:
            The completion response.
        """
        # Normalize all messages to Message models so serialization is uniform.
        cast_messages: list[Message] = [
            Message(**msg) if isinstance(msg, dict) else msg
            for msg in messages
        ]

        if generation_config and not isinstance(generation_config, dict):
            generation_config = generation_config.model_dump()

        data = {
            "messages": [msg.model_dump() for msg in cast_messages],
            "generation_config": generation_config,
        }
        return await self._make_request("POST", "completion", json=data)  # type: ignore

    @deprecated("Use client.retrieval.rag() instead")
    async def rag(
        self,
        query: str,
        rag_generation_config: Optional[dict | GenerationConfig] = None,
        chunk_search_settings: Optional[dict | SearchSettings] = None,
        graph_search_settings: Optional[dict | GraphSearchSettings] = None,
        task_prompt_override: Optional[str] = None,
        include_title_if_available: Optional[bool] = False,
    ) -> RAGResponse | AsyncGenerator[RAGResponse, None]:
        """
        Conducts a Retrieval Augmented Generation (RAG) search with the given query.

        Args:
            query (str): The query to search for.
            rag_generation_config (Optional[Union[dict, GenerationConfig]]): RAG generation configuration.
            chunk_search_settings (Optional[Union[dict, SearchSettings]]): Vector search settings.
            graph_search_settings (Optional[Union[dict, GraphSearchSettings]]): KG search settings.
            task_prompt_override (Optional[str]): Task prompt override.
            include_title_if_available (Optional[bool]): Include the title if available.

        Returns:
            Union[RAGResponse, AsyncGenerator[RAGResponse, None]]: The RAG
            response, or an async stream of responses when
            ``rag_generation_config["stream"]`` is true.
        """
        # Serialize any pydantic models to plain dicts before the request.
        if rag_generation_config and not isinstance(
            rag_generation_config, dict
        ):
            rag_generation_config = rag_generation_config.model_dump()

        if chunk_search_settings and not isinstance(
            chunk_search_settings, dict
        ):
            chunk_search_settings = chunk_search_settings.model_dump()

        if graph_search_settings and not isinstance(
            graph_search_settings, dict
        ):
            graph_search_settings = graph_search_settings.model_dump()

        data = {
            "query": query,
            "rag_generation_config": rag_generation_config,
            "chunk_search_settings": chunk_search_settings,
            "graph_search_settings": graph_search_settings,
            "task_prompt_override": task_prompt_override,
            "include_title_if_available": include_title_if_available,
        }

        # A truthy "stream" flag switches to the streaming transport; the
        # streaming call is returned un-awaited (it yields an async generator).
        if rag_generation_config and rag_generation_config.get(  # type: ignore
            "stream", False
        ):
            return self._make_streaming_request("POST", "rag", json=data)  # type: ignore
        else:
            return await self._make_request("POST", "rag", json=data)  # type: ignore

    @deprecated("Use client.retrieval.agent() instead")
    async def agent(
        self,
        message: Optional[dict | Message] = None,
        rag_generation_config: Optional[dict | GenerationConfig] = None,
        chunk_search_settings: Optional[dict | SearchSettings] = None,
        graph_search_settings: Optional[dict | GraphSearchSettings] = None,
        task_prompt_override: Optional[str] = None,
        include_title_if_available: Optional[bool] = False,
        conversation_id: Optional[str] = None,
        branch_id: Optional[str] = None,
        # TODO - Deprecate messages
        messages: Optional[list[dict | Message]] = None,
    ) -> list[Message] | AsyncGenerator[Message, None]:
        """
        Performs a single turn in a conversation with a RAG agent.

        Args:
            message (Optional[Union[dict, Message]]): The message for this turn.
            rag_generation_config (Optional[Union[dict, GenerationConfig]]): RAG generation configuration.
            chunk_search_settings (Optional[Union[dict, SearchSettings]]): Vector search settings.
            graph_search_settings (Optional[Union[dict, GraphSearchSettings]]): KG search settings.
            task_prompt_override (Optional[str]): Task prompt override.
            include_title_if_available (Optional[bool]): Include the title if available.
            conversation_id (Optional[str]): Conversation to continue.
            branch_id (Optional[str]): Conversation branch to use.
            messages (Optional[list[Union[dict, Message]]]): Deprecated;
                use ``message`` instead.

        Returns:
            Union[list[Message], AsyncGenerator[Message, None]]: The agent
            response, or an async stream of messages when
            ``rag_generation_config["stream"]`` is true.
        """
        if messages:
            logger.warning(
                "The `messages` argument is deprecated. Please use `message` instead."
            )

        # Serialize any pydantic models to plain dicts before the request.
        if rag_generation_config and not isinstance(
            rag_generation_config, dict
        ):
            rag_generation_config = rag_generation_config.model_dump()

        if chunk_search_settings and not isinstance(
            chunk_search_settings, dict
        ):
            chunk_search_settings = chunk_search_settings.model_dump()

        if graph_search_settings and not isinstance(
            graph_search_settings, dict
        ):
            graph_search_settings = graph_search_settings.model_dump()

        data = {
            "rag_generation_config": rag_generation_config or {},
            "chunk_search_settings": chunk_search_settings or {},
            "graph_search_settings": graph_search_settings,
            "task_prompt_override": task_prompt_override,
            "include_title_if_available": include_title_if_available,
            "conversation_id": conversation_id,
            "branch_id": branch_id,
        }

        if message:
            cast_message: Message = (
                Message(**message) if isinstance(message, dict) else message
            )
            data["message"] = cast_message.model_dump()

        # Legacy path: still forward `messages` for callers that pass it.
        if messages:
            data["messages"] = [
                (
                    Message(**msg).model_dump()  # type: ignore
                    if isinstance(msg, dict)
                    else msg.model_dump()  # type: ignore
                )
                for msg in messages
            ]

        if rag_generation_config and rag_generation_config.get(  # type: ignore
            "stream", False
        ):
            return self._make_streaming_request("POST", "agent", json=data)  # type: ignore
        else:
            return await self._make_request("POST", "agent", json=data)  # type: ignore

    @deprecated("Use client.retrieval.embedding() instead")
    async def embedding(
        self,
        content: str,
    ) -> list[float]:
        """
        Generate embeddings for the provided content.

        Args:
            content (str): The text content to embed.

        Returns:
            list[float]: The generated embedding vector.
        """
        # Note: the raw string is sent as the JSON body, not wrapped in a dict.
        return await self._make_request("POST", "embedding", json=content)  # type: ignore
|