file_search_tool_.py 2.4 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677
  1. import fnmatch
  2. import os
  3. from typing import Optional, Type
  4. from langchain_core.callbacks import CallbackManagerForToolRun
  5. from langchain_core.tools import BaseTool
  6. from pydantic import BaseModel, Field
  7. from langchain_community.tools.file_management.utils import (
  8. INVALID_PATH_TEMPLATE,
  9. BaseFileToolMixin,
  10. FileValidationError,
  11. )
  12. from app.services.file.file import FileService
  class FileSearchInput(BaseModel):
      """Input for FileSearchTool."""

      # Directory to search in; defaults to "." when the caller omits it.
      dir_path: str = Field(
          default=".",
          description="Subdirectory to search in.",
      )
      # FileSearchTool._run matches this against file names with
      # ``fnmatch.filter``: shell-style wildcards (*, ?, [seq]), NOT a regular
      # expression. The description is what the calling model reads, so it must
      # not advertise regex syntax (e.g. ".*\.py" would silently match nothing).
      pattern: str = Field(
          ...,
          description="Unix shell-style glob pattern, where * matches everything.",
      )
  class FileSearchTool(BaseFileToolMixin, BaseTool):  # type: ignore[override, override]
      """Tool that searches for files in a subdirectory that match a glob pattern.

      Two entry points are defined:

      * ``run(indexes, query)`` picks file keys by position from
        ``self.__keys`` and delegates a search to
        ``FileService.search_in_files``.
      * ``_run(pattern, dir_path)`` is the ``BaseTool`` hook described by
        ``args_schema``. It walks ``dir_path`` (resolved via
        ``get_relative_path`` from ``BaseFileToolMixin``) and lists the files
        whose names match ``pattern`` under ``fnmatch`` rules.

      NOTE(review): ``run`` overrides ``BaseTool.run`` with a different
      signature. LangChain machinery that calls ``run(tool_input, ...)``
      presumably breaks because of this. Confirm that callers only use this
      ``run`` directly, or rename it.
      """

      name: str = "file_search"
      args_schema: Type[BaseModel] = FileSearchInput
      description: str = (
          "Recursively search for files in a subdirectory that match the glob pattern"
      )

      def run(self, indexes: list[int], query: str) -> dict:
          """Search the files selected by ``indexes`` for ``query``.

          Args:
              indexes: Positions into ``self.__keys`` selecting which file keys
                  to search.
              query: Search text forwarded to ``FileService.search_in_files``.

          Returns:
              The result of ``FileService.search_in_files``. It is annotated as
              ``dict``; the exact shape is defined by ``FileService``.

          Raises:
              IndexError: If an index is out of range, assuming ``self.__keys``
                  is a list (TODO confirm).
          """
          # NOTE(review): inside this class ``self.__keys`` is name-mangled to
          # ``self._FileSearchTool__keys`` and nothing here assigns it. Confirm
          # where it is populated, or this raises AttributeError.
          file_keys = [self.__keys[index] for index in indexes]
          return FileService.search_in_files(query=query, file_keys=file_keys)

      def _run(
          self,
          pattern: str,
          dir_path: str = ".",
          run_manager: Optional[CallbackManagerForToolRun] = None,
      ) -> str:
          """List files under ``dir_path`` whose file names match ``pattern``.

          This was previously disabled by being wrapped in a string literal.
          ``BaseTool`` declares ``_run`` abstract, so without it this tool
          cannot be instantiated.

          Args:
              pattern: Shell-style glob applied to file names with
                  ``fnmatch.filter``.
              dir_path: Directory to search, validated and resolved by
                  ``get_relative_path``.
              run_manager: LangChain callback manager. It is accepted for the
                  ``BaseTool`` hook signature but unused here.

          Returns:
              Matching paths relative to the resolved ``dir_path``, one per line.
              If nothing matches, a "No files found" message. If ``dir_path``
              fails validation, ``INVALID_PATH_TEMPLATE`` filled in. On any
              other error, ``"Error: <message>"``.
          """
          try:
              dir_path_ = self.get_relative_path(dir_path)
          except FileValidationError:
              return INVALID_PATH_TEMPLATE.format(arg_name="dir_path", value=dir_path)

          matches = []
          try:
              for root, _, filenames in os.walk(dir_path_):
                  for filename in fnmatch.filter(filenames, pattern):
                      absolute_path = os.path.join(root, filename)
                      matches.append(os.path.relpath(absolute_path, dir_path_))
          except Exception as e:
              # Deliberate best-effort: tool errors are reported as text to the
              # caller instead of propagating.
              return "Error: " + str(e)

          if matches:
              return "\n".join(matches)
          return f"No files found for pattern {pattern} in directory {dir_path}"

      # TODO: Add aiofiles method