123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153 |
- from core.agent import R2RAgent, R2RStreamingAgent
- from core.base import (
- format_search_results_for_llm,
- format_search_results_for_stream,
- )
- from core.base.abstractions import (
- AggregateSearchResult,
- GraphSearchSettings,
- SearchSettings,
- WebSearchResponse,
- )
- from core.base.agent import AgentConfig, Tool
- from core.base.providers import CompletionProvider, DatabaseProvider
- from core.base.utils import to_async_generator
- from core.pipelines import SearchPipeline
- from core.providers import ( # PostgresDatabaseProvider,
- LiteLLMCompletionProvider,
- OpenAICompletionProvider,
- )
class RAGAgentMixin:
    """Mixin that equips an agent with ``local_search`` and ``web_search`` tools.

    The mixin stores the search pipeline and forwards every other constructor
    argument to the next class in the MRO. The methods below read
    ``self.config.tool_names`` and append to ``self._tools``; neither attribute
    is created here, so the class this is mixed into must provide them.
    """

    def __init__(self, search_pipeline: SearchPipeline, *args, **kwargs):
        """Store ``search_pipeline`` and delegate the rest to ``super()``.

        Args:
            search_pipeline: Pipeline used by the ``local_search`` tool.
            *args: Positional arguments forwarded unchanged to ``super()``.
            **kwargs: Keyword arguments forwarded unchanged to ``super()``.
        """
        # Assigned before delegating so the pipeline is already available if
        # a base initializer ends up calling back into this mixin.
        self.search_pipeline = search_pipeline
        super().__init__(*args, **kwargs)

    def _register_tools(self):
        """Append a ``Tool`` to ``self._tools`` for each configured tool name.

        Does nothing when ``self.config.tool_names`` is empty or ``None``.

        Raises:
            ValueError: If a configured name is neither ``"local_search"``
                nor ``"web_search"``. Tools registered before the offending
                name stay in ``self._tools``.
        """
        if not self.config.tool_names:
            return
        for tool_name in self.config.tool_names:
            if tool_name == "local_search":
                self._tools.append(self.local_search())
            elif tool_name == "web_search":
                self._tools.append(self.web_search())
            else:
                raise ValueError(f"Unsupported tool name: {tool_name}")

    def web_search(self) -> Tool:
        """Build the ``web_search`` tool definition backed by ``_web_search``."""
        return Tool(
            name="web_search",
            description="Search for information on the web.",
            results_function=self._web_search,
            llm_format_function=RAGAgentMixin.format_search_results_for_llm,
            stream_function=RAGAgentMixin.format_search_results_for_stream,
            parameters={
                "type": "object",
                "properties": {
                    "query": {
                        "type": "string",
                        "description": "The query to search Google with.",
                    },
                },
                "required": ["query"],
            },
        )

    async def _web_search(
        self,
        query: str,
        search_settings: SearchSettings,
        *args,
        **kwargs,
    ) -> AggregateSearchResult:
        """Query Serper and wrap the organic results in an aggregate result.

        Args:
            query: Search string passed to the Serper client.
            search_settings: Not used by this method.
            *args: Ignored.
            **kwargs: Ignored.

        Returns:
            An ``AggregateSearchResult`` whose ``web_search_results`` holds the
            organic results; the chunk and graph fields are ``None``.
        """
        import asyncio

        from .serper import SerperClient

        serper_client = SerperClient()
        # `get_raw` is a synchronous call; run it in a worker thread so the
        # event loop keeps serving other tasks while the request is in flight.
        # TODO - switch to a natively async client.
        # TODO - Move to search pipeline, make configurable.
        raw_results = await asyncio.to_thread(serper_client.get_raw, query)
        web_response = WebSearchResponse.from_serper_results(raw_results)
        return AggregateSearchResult(
            chunk_search_results=None,
            graph_search_results=None,
            # TODO - How do we feel about throwing away so much info?
            web_search_results=web_response.organic_results,
        )

    def local_search(self) -> Tool:
        """Build the ``local_search`` tool definition backed by ``_local_search``."""
        return Tool(
            name="local_search",
            description="Search your local knowledgebase using the R2R AI system",
            results_function=self._local_search,
            llm_format_function=RAGAgentMixin.format_search_results_for_llm,
            stream_function=RAGAgentMixin.format_search_results_for_stream,
            parameters={
                "type": "object",
                "properties": {
                    "query": {
                        "type": "string",
                        "description": "The query to search the local knowledgebase with.",
                    },
                },
                "required": ["query"],
            },
        )

    async def _local_search(
        self,
        query: str,
        search_settings: SearchSettings,
        *args,
        **kwargs,
    ) -> list[AggregateSearchResult]:
        """Run ``query`` through the stored search pipeline.

        Args:
            query: Search string; fed to the pipeline as a one-item async
                generator.
            search_settings: Forwarded to the pipeline as ``search_settings``.
            *args: Ignored.
            **kwargs: Ignored.

        Returns:
            Whatever ``self.search_pipeline.run`` returns, unchanged.
        """
        # NOTE(review): the annotation says `list[...]`, but the formatters
        # registered for this tool accept a single AggregateSearchResult --
        # confirm what the pipeline actually returns.
        response = await self.search_pipeline.run(
            to_async_generator([query]),
            state=None,
            search_settings=search_settings,
        )
        return response

    @staticmethod
    def format_search_results_for_stream(
        results: AggregateSearchResult,
    ) -> str:
        """Delegate to the module-level ``format_search_results_for_stream``."""
        return format_search_results_for_stream(results)

    @staticmethod
    def format_search_results_for_llm(
        results: AggregateSearchResult,
    ) -> str:
        """Delegate to the module-level ``format_search_results_for_llm``."""
        return format_search_results_for_llm(results)
class R2RRAGAgent(RAGAgentMixin, R2RAgent):
    """RAG agent composed from ``RAGAgentMixin`` and ``R2RAgent``."""

    def __init__(
        self,
        database_provider: DatabaseProvider,
        llm_provider: LiteLLMCompletionProvider | OpenAICompletionProvider,
        search_pipeline: SearchPipeline,
        config: AgentConfig,
    ):
        """Forward all four dependencies, by keyword, up the MRO.

        Args:
            database_provider: Passed through as ``database_provider``.
            llm_provider: Passed through as ``llm_provider``.
            search_pipeline: Passed through as ``search_pipeline``.
            config: Passed through as ``config``.
        """
        # Everything is handed over by keyword so each class in the
        # cooperative ``super().__init__`` chain can pick out what it needs.
        init_kwargs = {
            "database_provider": database_provider,
            "search_pipeline": search_pipeline,
            "llm_provider": llm_provider,
            "config": config,
        }
        super().__init__(**init_kwargs)
class R2RStreamingRAGAgent(RAGAgentMixin, R2RStreamingAgent):
    """RAG agent composed from ``RAGAgentMixin`` and ``R2RStreamingAgent``."""

    def __init__(
        self,
        database_provider: DatabaseProvider,
        llm_provider: LiteLLMCompletionProvider | OpenAICompletionProvider,
        search_pipeline: SearchPipeline,
        config: AgentConfig,
    ):
        """Force ``config.stream`` on, then forward everything up the MRO.

        Args:
            database_provider: Passed through as ``database_provider``.
            llm_provider: Passed through as ``llm_provider``.
            search_pipeline: Passed through as ``search_pipeline``.
            config: Passed through as ``config`` after its ``stream``
                attribute has been set to ``True``.
        """
        # The flag is set on the caller's config object itself (no copy is
        # made), and before any base initializer gets to see the config.
        config.stream = True
        init_kwargs = {
            "database_provider": database_provider,
            "search_pipeline": search_pipeline,
            "llm_provider": llm_provider,
            "config": config,
        }
        super().__init__(**init_kwargs)
|