# rag.py (5.0 KB)
import asyncio

from core.agent import R2RAgent, R2RStreamingAgent
from core.base import (
    format_search_results_for_llm,
    format_search_results_for_stream,
)
from core.base.abstractions import (
    AggregateSearchResult,
    GraphSearchSettings,
    SearchSettings,
    WebSearchResponse,
)
from core.base.agent import AgentConfig, Tool
from core.base.providers import CompletionProvider, DatabaseProvider
from core.base.utils import to_async_generator
from core.pipelines import SearchPipeline
from core.providers import (  # PostgresDatabaseProvider,
    LiteLLMCompletionProvider,
    OpenAICompletionProvider,
)
  20. class RAGAgentMixin:
  21. def __init__(self, search_pipeline: SearchPipeline, *args, **kwargs):
  22. self.search_pipeline = search_pipeline
  23. super().__init__(*args, **kwargs)
  24. def _register_tools(self):
  25. if not self.config.tool_names:
  26. return
  27. for tool_name in self.config.tool_names:
  28. if tool_name == "local_search":
  29. self._tools.append(self.local_search())
  30. elif tool_name == "web_search":
  31. self._tools.append(self.web_search())
  32. else:
  33. raise ValueError(f"Unsupported tool name: {tool_name}")
  34. def web_search(self) -> Tool:
  35. return Tool(
  36. name="web_search",
  37. description="Search for information on the web.",
  38. results_function=self._web_search,
  39. llm_format_function=RAGAgentMixin.format_search_results_for_llm,
  40. stream_function=RAGAgentMixin.format_search_results_for_stream,
  41. parameters={
  42. "type": "object",
  43. "properties": {
  44. "query": {
  45. "type": "string",
  46. "description": "The query to search Google with.",
  47. },
  48. },
  49. "required": ["query"],
  50. },
  51. )
  52. async def _web_search(
  53. self,
  54. query: str,
  55. search_settings: SearchSettings,
  56. *args,
  57. **kwargs,
  58. ) -> list[AggregateSearchResult]:
  59. from .serper import SerperClient
  60. serper_client = SerperClient()
  61. # TODO - make async!
  62. # TODO - Move to search pipeline, make configurable.
  63. raw_results = serper_client.get_raw(query)
  64. web_response = WebSearchResponse.from_serper_results(raw_results)
  65. return AggregateSearchResult(
  66. chunk_search_results=None,
  67. graph_search_results=None,
  68. web_search_results=web_response.organic_results, # TODO - How do we feel about throwing away so much info?
  69. )
  70. def local_search(self) -> Tool:
  71. return Tool(
  72. name="local_search",
  73. description="Search your local knowledgebase using the R2R AI system",
  74. results_function=self._local_search,
  75. llm_format_function=RAGAgentMixin.format_search_results_for_llm,
  76. stream_function=RAGAgentMixin.format_search_results_for_stream,
  77. parameters={
  78. "type": "object",
  79. "properties": {
  80. "query": {
  81. "type": "string",
  82. "description": "The query to search the local knowledgebase with.",
  83. },
  84. },
  85. "required": ["query"],
  86. },
  87. )
  88. async def _local_search(
  89. self,
  90. query: str,
  91. search_settings: SearchSettings,
  92. *args,
  93. **kwargs,
  94. ) -> list[AggregateSearchResult]:
  95. response = await self.search_pipeline.run(
  96. to_async_generator([query]),
  97. state=None,
  98. search_settings=search_settings,
  99. )
  100. return response
  101. @staticmethod
  102. def format_search_results_for_stream(
  103. results: AggregateSearchResult,
  104. ) -> str:
  105. return format_search_results_for_stream(results)
  106. @staticmethod
  107. def format_search_results_for_llm(
  108. results: AggregateSearchResult,
  109. ) -> str:
  110. return format_search_results_for_llm(results)
  111. class R2RRAGAgent(RAGAgentMixin, R2RAgent):
  112. def __init__(
  113. self,
  114. database_provider: DatabaseProvider,
  115. llm_provider: LiteLLMCompletionProvider | OpenAICompletionProvider,
  116. search_pipeline: SearchPipeline,
  117. config: AgentConfig,
  118. ):
  119. super().__init__(
  120. database_provider=database_provider,
  121. search_pipeline=search_pipeline,
  122. llm_provider=llm_provider,
  123. config=config,
  124. )
  125. class R2RStreamingRAGAgent(RAGAgentMixin, R2RStreamingAgent):
  126. def __init__(
  127. self,
  128. database_provider: DatabaseProvider,
  129. llm_provider: LiteLLMCompletionProvider | OpenAICompletionProvider,
  130. search_pipeline: SearchPipeline,
  131. config: AgentConfig,
  132. ):
  133. config.stream = True
  134. super().__init__(
  135. database_provider=database_provider,
  136. search_pipeline=search_pipeline,
  137. llm_provider=llm_provider,
  138. config=config,
  139. )