registry.py 6.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195
  1. import importlib
  2. import inspect
  3. import logging
  4. import os
  5. import pkgutil
  6. import sys
  7. from typing import Callable, Optional, Type
  8. from shared.abstractions.tool import Tool
  9. logger = logging.getLogger(__name__)
  10. class ToolRegistry:
  11. """
  12. Registry for discovering and managing tools from both
  13. built-in sources and user-defined extensions.
  14. """
  15. def __init__(
  16. self,
  17. built_in_path: str | None = None,
  18. user_tools_path: str | None = None,
  19. ):
  20. self.built_in_path = built_in_path or os.path.join(
  21. os.path.dirname(os.path.abspath(__file__)), "built_in"
  22. )
  23. self.user_tools_path = (
  24. user_tools_path
  25. or os.getenv("R2R_USER_TOOLS_PATH")
  26. or "../docker/user_tools"
  27. )
  28. # Tool storage
  29. self._built_in_tools: dict[str, Type[Tool]] = {}
  30. self._user_tools: dict[str, Type[Tool]] = {}
  31. # Discover tools
  32. self._discover_built_in_tools()
  33. if os.path.exists(self.user_tools_path):
  34. self._discover_user_tools()
  35. else:
  36. logger.warning(
  37. f"User tools directory not found: {self.user_tools_path}"
  38. )
  39. def _discover_built_in_tools(self):
  40. """Load all built-in tools from the built_in directory."""
  41. if not os.path.exists(self.built_in_path):
  42. logger.warning(
  43. f"Built-in tools directory not found: {self.built_in_path}"
  44. )
  45. return
  46. # Add to Python path if needed
  47. if self.built_in_path not in sys.path:
  48. sys.path.append(os.path.dirname(self.built_in_path))
  49. # Import the built_in package
  50. try:
  51. built_in_pkg = importlib.import_module("built_in")
  52. except ImportError:
  53. logger.error("Failed to import built_in tools package")
  54. return
  55. # Discover all modules in the package
  56. for _, module_name, is_pkg in pkgutil.iter_modules(
  57. [self.built_in_path]
  58. ):
  59. if is_pkg: # Skip subpackages
  60. continue
  61. try:
  62. module = importlib.import_module(f"built_in.{module_name}")
  63. # Find all tool classes in the module
  64. for name, obj in inspect.getmembers(module, inspect.isclass):
  65. if (
  66. issubclass(obj, Tool)
  67. and obj.__module__ == module.__name__
  68. and obj != Tool
  69. ):
  70. try:
  71. tool_instance = obj()
  72. self._built_in_tools[tool_instance.name] = obj
  73. logger.debug(
  74. f"Loaded built-in tool: {tool_instance.name}"
  75. )
  76. except Exception as e:
  77. logger.error(
  78. f"Error instantiating built-in tool {name}: {e}"
  79. )
  80. except Exception as e:
  81. logger.error(
  82. f"Error loading built-in tool module {module_name}: {e}"
  83. )
  84. def _discover_user_tools(self):
  85. """Scan the user tools directory for custom tools."""
  86. # Add user_tools directory to Python path if needed
  87. if self.user_tools_path not in sys.path:
  88. sys.path.append(os.path.dirname(self.user_tools_path))
  89. user_tools_pkg_name = os.path.basename(self.user_tools_path)
  90. # Check all Python files in user_tools directory
  91. for filename in os.listdir(self.user_tools_path):
  92. if (
  93. not filename.endswith(".py")
  94. or filename.startswith("_")
  95. or filename.startswith(".")
  96. ):
  97. continue
  98. module_name = filename[:-3] # Remove .py extension
  99. try:
  100. # Import the module
  101. module = importlib.import_module(
  102. f"{user_tools_pkg_name}.{module_name}"
  103. )
  104. # Find all tool classes in the module
  105. for name, obj in inspect.getmembers(module, inspect.isclass):
  106. if (
  107. issubclass(obj, Tool)
  108. and obj.__module__ == module.__name__
  109. and obj != Tool
  110. ):
  111. try:
  112. tool_instance = obj()
  113. self._user_tools[tool_instance.name] = obj
  114. logger.debug(
  115. f"Loaded user tool: {tool_instance.name}"
  116. )
  117. except Exception as e:
  118. logger.error(
  119. f"Error instantiating user tool {name}: {e}"
  120. )
  121. except Exception as e:
  122. logger.error(
  123. f"Error loading user tool module {module_name}: {e}"
  124. )
  125. def get_tool_class(self, tool_name: str):
  126. """Get a tool class by name."""
  127. if tool_name in self._user_tools:
  128. return self._user_tools[tool_name]
  129. return self._built_in_tools.get(tool_name)
  130. def list_available_tools(
  131. self, include_built_in=True, include_user=True
  132. ) -> list[str]:
  133. """
  134. List all available tool names.
  135. Optionally filter by built-in or user-defined tools.
  136. """
  137. tools: set[str] = set()
  138. if include_built_in:
  139. tools.update(self._built_in_tools.keys())
  140. if include_user:
  141. tools.update(self._user_tools.keys())
  142. return sorted(list(tools))
  143. def create_tool_instance(
  144. self, tool_name: str, format_function: Callable, context=None
  145. ) -> Optional[Tool]:
  146. """
  147. Create, configure, and return an instance of the specified tool.
  148. Returns None if the tool doesn't exist or instantiation fails.
  149. """
  150. tool_class = self.get_tool_class(tool_name)
  151. if not tool_class:
  152. logger.warning(f"Tool class not found for '{tool_name}'")
  153. return None
  154. try:
  155. tool_instance = tool_class()
  156. if hasattr(tool_instance, "llm_format_function"):
  157. tool_instance.llm_format_function = format_function
  158. # Set the context on the specific tool instance
  159. tool_instance.set_context(context)
  160. return tool_instance
  161. except Exception as e:
  162. logger.error(
  163. f"Error creating or setting context for tool instance '{tool_name}': {e}"
  164. )
  165. return None