file_search_tool.py 6.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134
import logging
from typing import Type, List

from pydantic import BaseModel, Field
from sqlalchemy.orm import Session

from app.core.tools.base_tool import BaseTool
from app.models.run import Run
from app.services.file.file import FileService
from app.services.assistant.assistant import AssistantService
  8. # return '## important:You can use the "retrieval" tool to search for relevant information.\n If you are asking about the content of the files, please specify any keywords, topics, or context you are looking for to help retrieve the most relevant content.'
  9. # query: str = Field(
  10. # ...,
  11. # description="query to look up in retrieval",
  12. # )
  13. # asyncio.set_event_loop_policy(asyncio.WindowsSelectorEventLoopPolicy())
  14. # query: str = Field(..., description="query to look up in retrieval")
  15. class FileSearchToolInput(BaseModel):
  16. query: str = Field(
  17. ...,
  18. description="query to look up in retrieval",
  19. )
  20. class FileSearchTool(BaseTool):
  21. name: str = "file_search"
  22. description: str = (
  23. # "Can be used to search through content of files uploaded by the user."
  24. # + "If the user references specific file content (e.g., 'in my uploaded document...'), this function should be triggered."
  25. # + "Singleton operation: Strictly 1 invocation per API call"
  26. "Use this function to search through the content of files uploaded by the user. "
  27. + "Trigger this function whenever the user refers to content within any uploaded document or file, even if they do not specify a file name or type. "
  28. + "If multiple files are uploaded and the user does not indicate a specific file, perform the search across all available uploaded files and return relevant results from each, clearly stating which file each result comes from. "
  29. + "If the user's request is ambiguous, default to considering all relevant uploaded files, and, if possible, provide a brief summary of the contents of each file to help clarify. "
  30. + "If the user has uploaded files in multiple batches, and their request is ambiguous (e.g. 'summarize the documents'), default to summarizing only the most recent batch of uploaded files. If the user's intent is to include older files, they should specify this explicitly."
  31. + "Prioritize providing useful information by erring on the side of inclusion rather than exclusion when the user's intent is not explicit. "
  32. + "Singleton operation: Strictly 1 invocation per API call."
  33. )
  34. args_schema: Type[BaseModel] = FileSearchToolInput
  35. def __init__(self) -> None:
  36. super().__init__()
  37. self.__filenames = []
  38. self.__keys = []
  39. self.__dirkeys = []
  40. self.loop = None
  41. self.index = 0
  42. def configure(self, session: Session, run: Run, **kwargs):
  43. # 获取当前事件循环
  44. # document_id = []
  45. file_key = []
  46. # filesinfo = []
  47. # 后语要从知识库里选择文件,所以在openassistant的数据库里可能不存在
  48. for key in run.file_ids:
  49. if len(key) == 36:
  50. self.__keys.append(key) # 添加文件id 作为检索
  51. else:
  52. file_key.append(
  53. key
  54. ) ## assiatant的id数据,在r2r里没办法检索需要提取filekey字段
  55. print(
  56. "document_iddocument_iddocument_iddocument_iddocument_iddocument_iddocument_iddocument_iddocument_iddocument_iddocument_iddocument_iddocument_iddocument_iddocument_iddocument_iddocument_iddocument_iddocument_iddocument_iddocument_iddocument_iddocument_iddocument_iddocument_iddocument_iddocument_id"
  57. )
  58. # print(document_id)
  59. print(file_key)
  60. files = []
  61. # 这种情况是uuid.ex 这种格式的在最早的时候存在的,后续要去掉
  62. if len(file_key) > 0:
  63. ## 获取文件信息
  64. files = FileService.get_file_list_by_ids(session=session, file_ids=file_key)
  65. for file in files:
  66. self.__keys.append(file.key)
  67. print(files)
  68. """
  69. # 读取assistant的数据,获取文件夹的id
  70. db_asst = AssistantService.get_assistant_sync(
  71. session=session, assistant_id=run.assistant_id
  72. )
  73. if db_asst.tool_resources and "file_search" in db_asst.tool_resources:
  74. ##{"file_search": {"vector_store_ids": [{"file_ids": []}]}}
  75. asst_folder_ids = (
  76. db_asst.tool_resources.get("file_search")
  77. .get("vector_stores")[0]
  78. .get("folder_ids")
  79. )
  80. print(asst_folder_ids)
  81. # folder_fileinfo = []
  82. if asst_folder_ids:
  83. self.__dirkeys = asst_folder_ids
  84. """
  85. # pre-cache data to prevent thread conflicts that may occur later on.
  86. print(
  87. "---------ssssssssssss-----------------sssssssssssss---------------ssssssssssssss-------------sssssssssssss-------------ss-------"
  88. )
  89. print(self.__dirkeys)
  90. print(self.__keys)
  91. # indexes: List[int],
  92. def run(self, query: str) -> dict:
  93. print(
  94. "file_keysfile_keysfile_keysfile_keysfile_keysfile_keysfile_keysfile_keysfile_keysfile_keysfile_keysfile_keysfile_keysfile_keysfile_keysfile_keysfile_keysfile_keysfile_keysfile_keysfile_keysfile_keysfile_keysfile_keysfile_keysfile_keysfile_keysfile_keysfile_keys"
  95. )
  96. print(self.__keys)
  97. print(self.__dirkeys)
  98. files = []
  99. ## 必须有总结的内容query和才能触发
  100. if self.index == 0 and query:
  101. try:
  102. files = FileService.search_in_files(
  103. query=query, file_keys=self.__keys, folder_keys=self.__dirkeys
  104. )
  105. self.index = 1
  106. except Exception as e:
  107. print(e)
  108. # print(files)
  109. return files
  110. def instruction_supplement(self) -> str:
  111. """
  112. 为 Retrieval 提供文件选择信息,用于 llm 调用抉择
  113. """
  114. if (self.__keys and len(self.__keys) > 0) or (
  115. self.__dirkeys and len(self.__dirkeys) > 0
  116. ):
  117. return ""
  118. else:
  119. return "如果您不确定用户发的文件内容或者代码库结构,请使用文件搜索工具读取内容并收集相关信息,不要瞎猜或者编造答案。"