123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517 |
- import json
- import logging
- import time
- from typing import Optional
- from uuid import UUID
- from fastapi import HTTPException
- from core import R2RStreamingRAGAgent
- from core.base import (
- DocumentResponse,
- GenerationConfig,
- Message,
- R2RException,
- RunManager,
- SearchSettings,
- manage_run,
- to_async_generator,
- )
- from core.base.api.models import CombinedSearchResponse, RAGResponse, User
- from core.telemetry.telemetry_decorator import telemetry_event
- from shared.api.models.management.responses import MessageResponse
- from ..abstractions import R2RAgents, R2RPipelines, R2RPipes, R2RProviders
- from ..config import R2RConfig
- from .base import Service
- logger = logging.getLogger()
- class RetrievalService(Service):
- def __init__(
- self,
- config: R2RConfig,
- providers: R2RProviders,
- pipes: R2RPipes,
- pipelines: R2RPipelines,
- agents: R2RAgents,
- run_manager: RunManager,
- ):
- super().__init__(
- config,
- providers,
- pipes,
- pipelines,
- agents,
- run_manager,
- )
- @telemetry_event("Search")
- async def search( # TODO - rename to 'search_chunks'
- self,
- query: str,
- search_settings: SearchSettings = SearchSettings(),
- *args,
- **kwargs,
- ) -> CombinedSearchResponse:
- async with manage_run(self.run_manager) as run_id:
- t0 = time.time()
- if (
- search_settings.use_semantic_search
- and self.config.database.provider is None
- ):
- raise R2RException(
- status_code=400,
- message="Vector search is not enabled in the configuration.",
- )
- if (
- (
- search_settings.use_semantic_search
- and search_settings.use_fulltext_search
- )
- or search_settings.use_hybrid_search
- ) and not search_settings.hybrid_settings:
- raise R2RException(
- status_code=400,
- message="Hybrid search settings must be specified in the input configuration.",
- )
- # TODO - Remove these transforms once we have a better way to handle this
- for filter, value in search_settings.filters.items():
- if isinstance(value, UUID):
- search_settings.filters[filter] = str(value)
- merged_kwargs = {
- "input": to_async_generator([query]),
- "state": None,
- "search_settings": search_settings,
- "run_manager": self.run_manager,
- **kwargs,
- }
- results = await self.pipelines.search_pipeline.run(
- *args,
- **merged_kwargs,
- )
- t1 = time.time()
- latency = f"{t1 - t0:.2f}"
- return results.as_dict()
- @telemetry_event("SearchDocuments")
- async def search_documents(
- self,
- query: str,
- settings: SearchSettings,
- query_embedding: Optional[list[float]] = None,
- ) -> list[DocumentResponse]:
- return (
- await self.providers.database.documents_handler.search_documents(
- query_text=query,
- settings=settings,
- query_embedding=query_embedding,
- )
- )
- @telemetry_event("Completion")
- async def completion(
- self,
- messages: list[dict],
- generation_config: GenerationConfig,
- *args,
- **kwargs,
- ):
- return await self.providers.llm.aget_completion(
- [message.to_dict() for message in messages],
- generation_config,
- *args,
- **kwargs,
- )
- @telemetry_event("Embedding")
- async def embedding(
- self,
- text: str,
- ):
- return await self.providers.embedding.async_get_embedding(text=text)
- @telemetry_event("RAG")
- async def rag(
- self,
- query: str,
- rag_generation_config: GenerationConfig,
- search_settings: SearchSettings = SearchSettings(),
- *args,
- **kwargs,
- ) -> RAGResponse:
- async with manage_run(self.run_manager) as run_id:
- try:
- # TODO - Remove these transforms once we have a better way to handle this
- for (
- filter,
- value,
- ) in search_settings.filters.items():
- if isinstance(value, UUID):
- search_settings.filters[filter] = str(value)
- if rag_generation_config.stream:
- return await self.stream_rag_response(
- query,
- rag_generation_config,
- search_settings,
- *args,
- **kwargs,
- )
- merged_kwargs = {
- "input": to_async_generator([query]),
- "state": None,
- "search_settings": search_settings,
- "run_manager": self.run_manager,
- "rag_generation_config": rag_generation_config,
- **kwargs,
- }
- results = await self.pipelines.rag_pipeline.run(
- *args,
- **merged_kwargs,
- )
- if len(results) == 0:
- raise R2RException(
- status_code=404, message="No results found"
- )
- if len(results) > 1:
- logger.warning(
- f"Multiple results found for query: {query}"
- )
- # unpack the first result
- return results[0]
- except Exception as e:
- logger.error(f"Pipeline error: {str(e)}")
- if "NoneType" in str(e):
- raise HTTPException(
- status_code=502,
- detail="Remote server not reachable or returned an invalid response",
- ) from e
- raise HTTPException(
- status_code=500, detail="Internal Server Error"
- ) from e
- async def stream_rag_response(
- self,
- query,
- rag_generation_config,
- search_settings,
- *args,
- **kwargs,
- ):
- async def stream_response():
- async with manage_run(self.run_manager, "rag"):
- merged_kwargs = {
- "input": to_async_generator([query]),
- "state": None,
- "run_manager": self.run_manager,
- "search_settings": search_settings,
- "rag_generation_config": rag_generation_config,
- **kwargs,
- }
- async for (
- chunk
- ) in await self.pipelines.streaming_rag_pipeline.run(
- *args,
- **merged_kwargs,
- ):
- yield chunk
- return stream_response()
- @telemetry_event("Agent")
- async def agent(
- self,
- rag_generation_config: GenerationConfig,
- search_settings: SearchSettings = SearchSettings(),
- task_prompt_override: Optional[str] = None,
- include_title_if_available: Optional[bool] = False,
- conversation_id: Optional[UUID] = None,
- message: Optional[Message] = None,
- messages: Optional[list[Message]] = None,
- ):
- async with manage_run(self.run_manager) as run_id:
- try:
- if message and messages:
- raise R2RException(
- status_code=400,
- message="Only one of message or messages should be provided",
- )
- if not message and not messages:
- raise R2RException(
- status_code=400,
- message="Either message or messages should be provided",
- )
- # Ensure 'message' is a Message instance
- if message and not isinstance(message, Message):
- if isinstance(message, dict):
- message = Message.from_dict(message)
- else:
- raise R2RException(
- status_code=400,
- message="""
- Invalid message format. The expected format contains:
- role: MessageType | 'system' | 'user' | 'assistant' | 'function'
- content: Optional[str]
- name: Optional[str]
- function_call: Optional[dict[str, Any]]
- tool_calls: Optional[list[dict[str, Any]]]
- """,
- )
- # Ensure 'messages' is a list of Message instances
- if messages:
- processed_messages = []
- for message in messages:
- if isinstance(message, Message):
- processed_messages.append(message)
- elif hasattr(message, "dict"):
- processed_messages.append(
- Message.from_dict(message.dict())
- )
- elif isinstance(message, dict):
- processed_messages.append(
- Message.from_dict(message)
- )
- else:
- processed_messages.append(
- Message.from_dict(str(message))
- )
- messages = processed_messages
- else:
- messages = []
- # Transform UUID filters to strings
- for filter_key, value in search_settings.filters.items():
- if isinstance(value, UUID):
- search_settings.filters[filter_key] = str(value)
- ids = []
- if conversation_id: # Fetch the existing conversation
- try:
- conversation_messages = await self.providers.database.conversations_handler.get_conversation(
- conversation_id=conversation_id,
- )
- except Exception as e:
- logger.error(f"Error fetching conversation: {str(e)}")
- if conversation_messages is not None:
- messages_from_conversation: list[Message] = []
- for message_response in conversation_messages:
- if isinstance(message_response, MessageResponse):
- messages_from_conversation.append(
- message_response.message
- )
- ids.append(message_response.id)
- else:
- logger.warning(
- f"Unexpected type in conversation found: {type(message_response)}\n{message_response}"
- )
- messages = messages_from_conversation + messages
- else: # Create new conversation
- conversation_response = (
- await self.providers.database.conversations_handler.create_conversation()
- )
- conversation_id = conversation_response.id
- if message:
- messages.append(message)
- if not messages:
- raise R2RException(
- status_code=400,
- message="No messages to process",
- )
- current_message = messages[-1]
- # Save the new message to the conversation
- parent_id = ids[-1] if ids else None
- message_response = await self.providers.database.conversations_handler.add_message(
- conversation_id=conversation_id,
- content=current_message,
- parent_id=parent_id,
- )
- message_id = (
- message_response.id
- if message_response is not None
- else None
- )
- if rag_generation_config.stream:
- async def stream_response():
- async with manage_run(self.run_manager, "rag_agent"):
- agent = R2RStreamingRAGAgent(
- database_provider=self.providers.database,
- llm_provider=self.providers.llm,
- config=self.config.agent,
- search_pipeline=self.pipelines.search_pipeline,
- )
- async for chunk in agent.arun(
- messages=messages,
- system_instruction=task_prompt_override,
- search_settings=search_settings,
- rag_generation_config=rag_generation_config,
- include_title_if_available=include_title_if_available,
- ):
- yield chunk
- return stream_response()
- results = await self.agents.rag_agent.arun(
- messages=messages,
- system_instruction=task_prompt_override,
- search_settings=search_settings,
- rag_generation_config=rag_generation_config,
- include_title_if_available=include_title_if_available,
- )
- # Save the assistant's reply to the conversation
- if isinstance(results[-1], dict):
- assistant_message = Message(**results[-1])
- elif isinstance(results[-1], Message):
- assistant_message = results[-1]
- else:
- assistant_message = Message(
- role="assistant", content=str(results[-1])
- )
- await self.providers.database.conversations_handler.add_message(
- conversation_id=conversation_id,
- content=assistant_message,
- parent_id=message_id,
- )
- return {
- "messages": results,
- "conversation_id": str(
- conversation_id
- ), # Ensure it's a string
- }
- except Exception as e:
- logger.error(f"Error in agent response: {str(e)}")
- if "NoneType" in str(e):
- raise HTTPException(
- status_code=502,
- detail="Server not reachable or returned an invalid response",
- )
- raise HTTPException(
- status_code=500,
- detail=f"Internal Server Error - {str(e)}",
- )
- class RetrievalServiceAdapter:
- @staticmethod
- def _parse_user_data(user_data):
- if isinstance(user_data, str):
- try:
- user_data = json.loads(user_data)
- except json.JSONDecodeError:
- raise ValueError(f"Invalid user data format: {user_data}")
- return User.from_dict(user_data)
- @staticmethod
- def prepare_search_input(
- query: str,
- search_settings: SearchSettings,
- user: User,
- ) -> dict:
- return {
- "query": query,
- "search_settings": search_settings.to_dict(),
- "user": user.to_dict(),
- }
- @staticmethod
- def parse_search_input(data: dict):
- return {
- "query": data["query"],
- "search_settings": SearchSettings.from_dict(
- data["search_settings"]
- ),
- "user": RetrievalServiceAdapter._parse_user_data(data["user"]),
- }
- @staticmethod
- def prepare_rag_input(
- query: str,
- search_settings: SearchSettings,
- rag_generation_config: GenerationConfig,
- task_prompt_override: Optional[str],
- user: User,
- ) -> dict:
- return {
- "query": query,
- "search_settings": search_settings.to_dict(),
- "rag_generation_config": rag_generation_config.to_dict(),
- "task_prompt_override": task_prompt_override,
- "user": user.to_dict(),
- }
- @staticmethod
- def parse_rag_input(data: dict):
- return {
- "query": data["query"],
- "search_settings": SearchSettings.from_dict(
- data["search_settings"]
- ),
- "rag_generation_config": GenerationConfig.from_dict(
- data["rag_generation_config"]
- ),
- "task_prompt_override": data["task_prompt_override"],
- "user": RetrievalServiceAdapter._parse_user_data(data["user"]),
- }
- @staticmethod
- def prepare_agent_input(
- message: Message,
- search_settings: SearchSettings,
- rag_generation_config: GenerationConfig,
- task_prompt_override: Optional[str],
- include_title_if_available: bool,
- user: User,
- conversation_id: Optional[str] = None,
- ) -> dict:
- return {
- "message": message.to_dict(),
- "search_settings": search_settings.to_dict(),
- "rag_generation_config": rag_generation_config.to_dict(),
- "task_prompt_override": task_prompt_override,
- "include_title_if_available": include_title_if_available,
- "user": user.to_dict(),
- "conversation_id": conversation_id,
- }
- @staticmethod
- def parse_agent_input(data: dict):
- return {
- "message": Message.from_dict(data["message"]),
- "search_settings": SearchSettings.from_dict(
- data["search_settings"]
- ),
- "rag_generation_config": GenerationConfig.from_dict(
- data["rag_generation_config"]
- ),
- "task_prompt_override": data["task_prompt_override"],
- "include_title_if_available": data["include_title_if_available"],
- "user": RetrievalServiceAdapter._parse_user_data(data["user"]),
- "conversation_id": data.get("conversation_id"),
- }
|