# rag.py (11 KB)
  1. # type: ignore
  2. import logging
  3. from typing import Callable, Optional
  4. from core.base import (
  5. format_search_results_for_llm,
  6. )
  7. from core.base.abstractions import (
  8. AggregateSearchResult,
  9. GenerationConfig,
  10. SearchSettings,
  11. )
  12. from core.base.agent.tools.registry import ToolRegistry
  13. from core.base.providers import DatabaseProvider
  14. from core.providers import (
  15. AnthropicCompletionProvider,
  16. LiteLLMCompletionProvider,
  17. OpenAICompletionProvider,
  18. R2RCompletionProvider,
  19. )
  20. from core.utils import (
  21. SearchResultsCollector,
  22. num_tokens,
  23. )
  24. from ..base.agent.agent import RAGAgentConfig
  25. # Import the base classes from the refactored base file
  26. from .base import (
  27. R2RAgent,
  28. R2RStreamingAgent,
  29. R2RXMLStreamingAgent,
  30. R2RXMLToolsAgent,
  31. )
  32. logger = logging.getLogger(__name__)
  33. class RAGAgentMixin:
  34. """
  35. A Mixin for adding search_file_knowledge, web_search, and content tools
  36. to your R2R Agents. This allows your agent to:
  37. - call knowledge_search_method (semantic/hybrid search)
  38. - call content_method (fetch entire doc/chunk structures)
  39. - call an external web search API
  40. """
  41. def __init__(
  42. self,
  43. *args,
  44. search_settings: SearchSettings,
  45. knowledge_search_method: Callable,
  46. content_method: Callable,
  47. file_search_method: Callable,
  48. max_tool_context_length=10_000,
  49. max_context_window_tokens=512_000,
  50. tool_registry: Optional[ToolRegistry] = None,
  51. **kwargs,
  52. ):
  53. # Save references to the retrieval logic
  54. self.search_settings = search_settings
  55. self.knowledge_search_method = knowledge_search_method
  56. self.content_method = content_method
  57. self.file_search_method = file_search_method
  58. self.max_tool_context_length = max_tool_context_length
  59. self.max_context_window_tokens = max_context_window_tokens
  60. self.search_results_collector = SearchResultsCollector()
  61. self.tool_registry = tool_registry or ToolRegistry()
  62. super().__init__(*args, **kwargs)
  63. def _register_tools(self):
  64. """
  65. Register all requested tools from self.config.rag_tools using the ToolRegistry.
  66. """
  67. if not self.config.rag_tools:
  68. logger.warning(
  69. "No RAG tools requested. Skipping tool registration."
  70. )
  71. return
  72. # Make sure tool_registry exists
  73. if not hasattr(self, "tool_registry") or self.tool_registry is None:
  74. self.tool_registry = ToolRegistry()
  75. format_function = self.format_search_results_for_llm
  76. for tool_name in set(self.config.rag_tools):
  77. # Try to get the tools from the registry
  78. if tool_instance := self.tool_registry.create_tool_instance(
  79. tool_name, format_function, context=self
  80. ):
  81. logger.debug(
  82. f"Successfully registered tool from registry: {tool_name}"
  83. )
  84. self._tools.append(tool_instance)
  85. else:
  86. logger.warning(f"Unknown tool requested: {tool_name}")
  87. logger.debug(f"Registered {len(self._tools)} RAG tools.")
  88. def format_search_results_for_llm(
  89. self, results: AggregateSearchResult
  90. ) -> str:
  91. context = format_search_results_for_llm(results)
  92. context_tokens = num_tokens(context) + 1
  93. frac_to_return = self.max_tool_context_length / (context_tokens)
  94. if frac_to_return > 1:
  95. return context
  96. else:
  97. return context[: int(frac_to_return * len(context))]
  98. class R2RRAGAgent(RAGAgentMixin, R2RAgent):
  99. """
  100. Non-streaming RAG Agent that supports search_file_knowledge, content, web_search.
  101. """
  102. def __init__(
  103. self,
  104. database_provider: DatabaseProvider,
  105. llm_provider: (
  106. AnthropicCompletionProvider
  107. | LiteLLMCompletionProvider
  108. | OpenAICompletionProvider
  109. | R2RCompletionProvider
  110. ),
  111. config: RAGAgentConfig,
  112. search_settings: SearchSettings,
  113. rag_generation_config: GenerationConfig,
  114. knowledge_search_method: Callable,
  115. content_method: Callable,
  116. file_search_method: Callable,
  117. tool_registry: Optional[ToolRegistry] = None,
  118. max_tool_context_length: int = 20_000,
  119. ):
  120. # Initialize base R2RAgent
  121. R2RAgent.__init__(
  122. self,
  123. database_provider=database_provider,
  124. llm_provider=llm_provider,
  125. config=config,
  126. rag_generation_config=rag_generation_config,
  127. )
  128. self.tool_registry = tool_registry or ToolRegistry()
  129. # Initialize the RAGAgentMixin
  130. RAGAgentMixin.__init__(
  131. self,
  132. database_provider=database_provider,
  133. llm_provider=llm_provider,
  134. config=config,
  135. search_settings=search_settings,
  136. rag_generation_config=rag_generation_config,
  137. max_tool_context_length=max_tool_context_length,
  138. knowledge_search_method=knowledge_search_method,
  139. file_search_method=file_search_method,
  140. content_method=content_method,
  141. tool_registry=tool_registry,
  142. )
  143. self._register_tools()
  144. class R2RXMLToolsRAGAgent(RAGAgentMixin, R2RXMLToolsAgent):
  145. """
  146. Non-streaming RAG Agent that supports search_file_knowledge, content, web_search.
  147. """
  148. def __init__(
  149. self,
  150. database_provider: DatabaseProvider,
  151. llm_provider: (
  152. AnthropicCompletionProvider
  153. | LiteLLMCompletionProvider
  154. | OpenAICompletionProvider
  155. | R2RCompletionProvider
  156. ),
  157. config: RAGAgentConfig,
  158. search_settings: SearchSettings,
  159. rag_generation_config: GenerationConfig,
  160. knowledge_search_method: Callable,
  161. content_method: Callable,
  162. file_search_method: Callable,
  163. tool_registry: Optional[ToolRegistry] = None,
  164. max_tool_context_length: int = 20_000,
  165. ):
  166. # Initialize base R2RAgent
  167. R2RXMLToolsAgent.__init__(
  168. self,
  169. database_provider=database_provider,
  170. llm_provider=llm_provider,
  171. config=config,
  172. rag_generation_config=rag_generation_config,
  173. )
  174. self.tool_registry = tool_registry or ToolRegistry()
  175. # Initialize the RAGAgentMixin
  176. RAGAgentMixin.__init__(
  177. self,
  178. database_provider=database_provider,
  179. llm_provider=llm_provider,
  180. config=config,
  181. search_settings=search_settings,
  182. rag_generation_config=rag_generation_config,
  183. max_tool_context_length=max_tool_context_length,
  184. knowledge_search_method=knowledge_search_method,
  185. file_search_method=file_search_method,
  186. content_method=content_method,
  187. tool_registry=tool_registry,
  188. )
  189. self._register_tools()
  190. class R2RStreamingRAGAgent(RAGAgentMixin, R2RStreamingAgent):
  191. """
  192. Streaming-capable RAG Agent that supports search_file_knowledge, content, web_search,
  193. and emits citations as [abc1234] short IDs if the LLM includes them in brackets.
  194. """
  195. def __init__(
  196. self,
  197. database_provider: DatabaseProvider,
  198. llm_provider: (
  199. AnthropicCompletionProvider
  200. | LiteLLMCompletionProvider
  201. | OpenAICompletionProvider
  202. | R2RCompletionProvider
  203. ),
  204. config: RAGAgentConfig,
  205. search_settings: SearchSettings,
  206. rag_generation_config: GenerationConfig,
  207. knowledge_search_method: Callable,
  208. content_method: Callable,
  209. file_search_method: Callable,
  210. tool_registry: Optional[ToolRegistry] = None,
  211. max_tool_context_length: int = 10_000,
  212. ):
  213. # Force streaming on
  214. config.stream = True
  215. # Initialize base R2RStreamingAgent
  216. R2RStreamingAgent.__init__(
  217. self,
  218. database_provider=database_provider,
  219. llm_provider=llm_provider,
  220. config=config,
  221. rag_generation_config=rag_generation_config,
  222. )
  223. self.tool_registry = tool_registry or ToolRegistry()
  224. # Initialize the RAGAgentMixin
  225. RAGAgentMixin.__init__(
  226. self,
  227. database_provider=database_provider,
  228. llm_provider=llm_provider,
  229. config=config,
  230. search_settings=search_settings,
  231. rag_generation_config=rag_generation_config,
  232. max_tool_context_length=max_tool_context_length,
  233. knowledge_search_method=knowledge_search_method,
  234. content_method=content_method,
  235. file_search_method=file_search_method,
  236. tool_registry=tool_registry,
  237. )
  238. self._register_tools()
  239. class R2RXMLToolsStreamingRAGAgent(RAGAgentMixin, R2RXMLStreamingAgent):
  240. """
  241. A streaming agent that:
  242. - treats <think> or <Thought> blocks as chain-of-thought
  243. and emits them incrementally as SSE "thinking" events.
  244. - accumulates user-visible text outside those tags as SSE "message" events.
  245. - filters out all XML tags related to tool calls and actions.
  246. - upon finishing each iteration, it parses <Action><ToolCalls><ToolCall> blocks,
  247. calls the appropriate tool, and emits SSE "tool_call" / "tool_result".
  248. - properly emits citations when they appear in the text
  249. """
  250. def __init__(
  251. self,
  252. database_provider: DatabaseProvider,
  253. llm_provider: (
  254. AnthropicCompletionProvider
  255. | LiteLLMCompletionProvider
  256. | OpenAICompletionProvider
  257. | R2RCompletionProvider
  258. ),
  259. config: RAGAgentConfig,
  260. search_settings: SearchSettings,
  261. rag_generation_config: GenerationConfig,
  262. knowledge_search_method: Callable,
  263. content_method: Callable,
  264. file_search_method: Callable,
  265. tool_registry: Optional[ToolRegistry] = None,
  266. max_tool_context_length: int = 10_000,
  267. ):
  268. # Force streaming on
  269. config.stream = True
  270. # Initialize base R2RXMLStreamingAgent
  271. R2RXMLStreamingAgent.__init__(
  272. self,
  273. database_provider=database_provider,
  274. llm_provider=llm_provider,
  275. config=config,
  276. rag_generation_config=rag_generation_config,
  277. )
  278. self.tool_registry = tool_registry or ToolRegistry()
  279. # Initialize the RAGAgentMixin
  280. RAGAgentMixin.__init__(
  281. self,
  282. database_provider=database_provider,
  283. llm_provider=llm_provider,
  284. config=config,
  285. search_settings=search_settings,
  286. rag_generation_config=rag_generation_config,
  287. max_tool_context_length=max_tool_context_length,
  288. knowledge_search_method=knowledge_search_method,
  289. content_method=content_method,
  290. file_search_method=file_search_method,
  291. tool_registry=tool_registry,
  292. )
  293. self._register_tools()