base_tool.py 1.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475
  1. from abc import ABC
  2. from typing import Type, Dict, Any, Optional
  3. from langchain.tools import BaseTool as LCBaseTool
  4. from langchain.tools.render import format_tool_to_openai_function
  5. from pydantic import BaseModel, Field
  class BaseToolInput(BaseModel):
      """Base schema for tool input arguments."""

      # NOTE: pydantic surfaces the class docstring as the schema description,
      # so the docstring wording above is intentionally kept stable.

      # The one and only argument: a required string. The Ellipsis default is
      # pydantic's marker for "required, no default value".
      input: str = Field(default=..., description="input")
  class BaseTool(ABC):
      """
      Base class for tools.

      When a subclass is created, ``__init_subclass__`` wraps the subclass's
      ``name``, ``description`` and ``args_schema`` in a throwaway ``LCTool``
      and stores the rendered OpenAI function spec on the subclass as
      ``openai_function``.

      Attributes:
          name (str): The name of the tool. Declared without a default, so
              each subclass has to provide it.
          description (str): The description of the tool. Declared without a
              default, so each subclass has to provide it.
          args_schema (Optional[Type[BaseModel]]): The schema for the tool's
              input arguments. Defaults to ``BaseToolInput``.
          openai_function (Dict): The OpenAI function representation of the
              tool, filled in automatically for every subclass.
      """

      name: str
      description: str
      args_schema: Optional[Type[BaseModel]] = BaseToolInput
      openai_function: Dict

      def __init_subclass__(cls, **kwargs: Any) -> None:
          """
          Build and attach the OpenAI function spec for a new subclass.

          Args:
              **kwargs: Class keyword arguments, forwarded untouched to the
                  next ``__init_subclass__`` in the MRO.

          Raises:
              AttributeError: If the subclass (or one of its bases) does not
                  define ``name`` or ``description``.
          """
          # Cooperate with other base classes / mixins that also hook subclass
          # creation; without this call their hooks would be skipped.
          super().__init_subclass__(**kwargs)

          # LCTool is defined further down in this module. That is fine because
          # this hook only runs when a subclass is created, not when BaseTool
          # itself is defined.
          # NOTE(review): the `_run` keyword is passed through exactly as
          # before; presumably a placeholder since the tool is only rendered,
          # never executed -- confirm against the LangChain version in use.
          lc_tool = LCTool(
              name=cls.name,
              description=cls.description,
              args_schema=cls.args_schema,
              _run=lambda x: x,
          )
          cls.openai_function = {
              "type": "function",
              "function": dict(format_tool_to_openai_function(lc_tool)),
          }

      def configure(self, **kwargs: Any) -> None:
          """
          Configure the tool with the provided keyword arguments.

          The base implementation ignores its arguments and does nothing;
          subclasses may override it.

          Args:
              **kwargs: Additional configuration parameters.
          """
          return

      def run(self, **kwargs: Any) -> Any:
          """
          Executes the tool with the given arguments.

          Subclasses are expected to override this method. It is not marked
          with ``abstractmethod``, so a subclass that forgets to override it
          can still be instantiated and only fails when ``run`` is called.

          Args:
              **kwargs: Additional keyword arguments for the tool.

          Returns:
              Any: The result of executing the tool.

          Raises:
              NotImplementedError: Always, in this base implementation.
          """
          raise NotImplementedError(f"{type(self).__name__} must implement run()")

      def instruction_supplement(self) -> str:
          """
          Provides additional instructions to supplement the run instruction for the tool.

          Returns:
              str: The additional instructions. The base implementation
              returns an empty string.
          """
          return ""
  class LCTool(LCBaseTool):
      # Minimal concrete LangChain tool. In this file it is only instantiated
      # by BaseTool.__init_subclass__ so that the tool's metadata can be passed
      # to format_tool_to_openai_function.

      # Empty-string defaults; the real values are supplied as keyword
      # arguments when the instance is created.
      name: str = ""
      description: str = ""

      def _run(self):
          # Intentional no-op placeholder: it takes no arguments and yields None.
          return None