file_search_tool.py 2.5 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273
  1. from typing import Type, List
  2. from pydantic import BaseModel, Field
  3. from sqlalchemy.orm import Session
  4. from app.core.tools.base_tool import BaseTool
  5. from app.models.run import Run
  6. from app.services.file.file import FileService
  7. class FileSearchToolInput(BaseModel):
  8. indexes: List[int] = Field(
  9. ..., description="file index list to look up in retrieval"
  10. )
  11. query: str = Field(..., description="query to look up in retrieval")
  12. class FileSearchTool(BaseTool):
  13. name: str = "file_search"
  14. description: str = (
  15. "Can be used to look up information that was uploaded to this assistant."
  16. "If the user is referencing particular files, that is often a good hint that information may be here."
  17. )
  18. args_schema: Type[BaseModel] = FileSearchToolInput
  19. def __init__(self) -> None:
  20. super().__init__()
  21. self.__filenames = []
  22. self.__keys = []
  23. def configure(self, session: Session, run: Run, **kwargs):
  24. """
  25. 置当前 Retrieval 涉及文件信息
  26. """
  27. files = FileService.get_file_list_by_ids(session=session, file_ids=run.file_ids)
  28. # pre-cache data to prevent thread conflicts that may occur later on.
  29. print(
  30. "---------ssssssssssss-----------------sssssssssssss---------------ssssssssssssss-------------sssssssssssss-------------ss-------"
  31. )
  32. print(files)
  33. for file in files:
  34. self.__filenames.append(file.filename)
  35. self.__keys.append(file.key)
  36. print(self.__keys)
  37. def run(self, indexes: List[int], query: str) -> dict:
  38. file_keys = []
  39. print(self.__keys)
  40. for index in indexes:
  41. file_key = self.__keys[index]
  42. file_keys.append(file_key)
  43. files = FileService.search_in_files(query=query, file_keys=file_keys)
  44. return files
  45. def instruction_supplement(self) -> str:
  46. """
  47. 为 Retrieval 提供文件选择信息,用于 llm 调用抉择
  48. """
  49. if len(self.__filenames) == 0:
  50. return ""
  51. else:
  52. filenames_info = [
  53. f"({index}){filename}"
  54. for index, filename in enumerate(self.__filenames)
  55. ]
  56. return (
  57. 'You can use the "retrieval" tool to retrieve relevant context from the following attached files. '
  58. + 'Each line represents a file in the format "(index)filename":\n'
  59. + "\n".join(filenames_info)
  60. + "\nMake sure to be extremely concise when using attached files. "
  61. )