rag.py 5.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159
  1. from typing import Union
  2. from core.agent import R2RAgent, R2RStreamingAgent
  3. from core.base import (
  4. format_search_results_for_llm,
  5. format_search_results_for_stream,
  6. )
  7. from core.base.abstractions import (
  8. AggregateSearchResult,
  9. GraphSearchSettings,
  10. SearchSettings,
  11. WebSearchResponse,
  12. )
  13. from core.base.agent import AgentConfig, Tool
  14. from core.base.providers import CompletionProvider, DatabaseProvider
  15. from core.base.utils import to_async_generator
  16. from core.pipelines import SearchPipeline
  17. from core.providers import ( # PostgresDatabaseProvider,
  18. LiteLLMCompletionProvider,
  19. OpenAICompletionProvider,
  20. )
  21. class RAGAgentMixin:
  22. def __init__(self, search_pipeline: SearchPipeline, *args, **kwargs):
  23. self.search_pipeline = search_pipeline
  24. super().__init__(*args, **kwargs)
  25. def _register_tools(self):
  26. if not self.config.tool_names:
  27. return
  28. for tool_name in self.config.tool_names:
  29. if tool_name == "local_search":
  30. self._tools.append(self.local_search())
  31. elif tool_name == "web_search":
  32. self._tools.append(self.web_search())
  33. else:
  34. raise ValueError(f"Unsupported tool name: {tool_name}")
  35. def web_search(self) -> Tool:
  36. return Tool(
  37. name="web_search",
  38. description="Search for information on the web.",
  39. results_function=self._web_search,
  40. llm_format_function=RAGAgentMixin.format_search_results_for_llm,
  41. stream_function=RAGAgentMixin.format_search_results_for_stream,
  42. parameters={
  43. "type": "object",
  44. "properties": {
  45. "query": {
  46. "type": "string",
  47. "description": "The query to search Google with.",
  48. },
  49. },
  50. "required": ["query"],
  51. },
  52. )
  53. async def _web_search(
  54. self,
  55. query: str,
  56. search_settings: SearchSettings,
  57. *args,
  58. **kwargs,
  59. ) -> list[AggregateSearchResult]:
  60. from .serper import SerperClient
  61. serper_client = SerperClient()
  62. # TODO - make async!
  63. # TODO - Move to search pipeline, make configurable.
  64. raw_results = serper_client.get_raw(query)
  65. web_response = WebSearchResponse.from_serper_results(raw_results)
  66. return AggregateSearchResult(
  67. chunk_search_results=None,
  68. graph_search_results=None,
  69. web_search_results=web_response.organic_results, # TODO - How do we feel about throwing away so much info?
  70. )
  71. def local_search(self) -> Tool:
  72. return Tool(
  73. name="local_search",
  74. description="Search your local knowledgebase using the R2R AI system",
  75. results_function=self._local_search,
  76. llm_format_function=RAGAgentMixin.format_search_results_for_llm,
  77. stream_function=RAGAgentMixin.format_search_results_for_stream,
  78. parameters={
  79. "type": "object",
  80. "properties": {
  81. "query": {
  82. "type": "string",
  83. "description": "The query to search the local knowledgebase with.",
  84. },
  85. },
  86. "required": ["query"],
  87. },
  88. )
  89. async def _local_search(
  90. self,
  91. query: str,
  92. search_settings: SearchSettings,
  93. *args,
  94. **kwargs,
  95. ) -> list[AggregateSearchResult]:
  96. response = await self.search_pipeline.run(
  97. to_async_generator([query]),
  98. state=None,
  99. search_settings=search_settings,
  100. )
  101. return response
  102. @staticmethod
  103. def format_search_results_for_stream(
  104. results: AggregateSearchResult,
  105. ) -> str:
  106. return format_search_results_for_stream(results)
  107. @staticmethod
  108. def format_search_results_for_llm(
  109. results: AggregateSearchResult,
  110. ) -> str:
  111. return format_search_results_for_llm(results)
  112. class R2RRAGAgent(RAGAgentMixin, R2RAgent):
  113. def __init__(
  114. self,
  115. database_provider: DatabaseProvider,
  116. llm_provider: Union[
  117. LiteLLMCompletionProvider, OpenAICompletionProvider
  118. ],
  119. search_pipeline: SearchPipeline,
  120. config: AgentConfig,
  121. ):
  122. super().__init__(
  123. database_provider=database_provider,
  124. search_pipeline=search_pipeline,
  125. llm_provider=llm_provider,
  126. config=config,
  127. )
  128. class R2RStreamingRAGAgent(RAGAgentMixin, R2RStreamingAgent):
  129. def __init__(
  130. self,
  131. database_provider: DatabaseProvider,
  132. llm_provider: Union[
  133. LiteLLMCompletionProvider, OpenAICompletionProvider
  134. ],
  135. search_pipeline: SearchPipeline,
  136. config: AgentConfig,
  137. ):
  138. config.stream = True
  139. super().__init__(
  140. database_provider=database_provider,
  141. search_pipeline=search_pipeline,
  142. llm_provider=llm_provider,
  143. config=config,
  144. )