jack 1 月之前
父節點
當前提交
52e296e954
共有 1 個文件被更改,包括 50 次插入和 56 次删除
  1. 50 56
      app/core/tools/file_search_tool.py

+ 50 - 56
app/core/tools/file_search_tool.py

@@ -14,10 +14,13 @@ import nest_asyncio
 # 使得异步代码可以在已运行的事件循环中嵌套
 nest_asyncio.apply()
 
-
 # asyncio.set_event_loop_policy(asyncio.WindowsSelectorEventLoopPolicy())
-# query: str = Field(..., description="query to look up in retrieval")
+
+
 class FileSearchToolInput(BaseModel):
+    indexes: List[int] = Field(
+        ..., description="file index list to look up in retrieval"
+    )
     query: str = Field(..., description="query to look up in retrieval")
 
 
@@ -25,10 +28,7 @@ class FileSearchTool(BaseTool):
     name: str = "file_search"
     description: str = (
         "Can be used to look up information that was uploaded to this assistant."
-        # "If the user is referencing particular files, that is often a good hint that information may be here."
-        "A search engine optimized for comprehensive, accurate, and trusted results. "
-        "Useful for when you need to answer questions about current events. "
-        "Input should be a search query."
+        "If the user is referencing particular files, that is often a good hint that information may be here."
     )
 
     args_schema: Type[BaseModel] = FileSearchToolInput
@@ -37,38 +37,41 @@ class FileSearchTool(BaseTool):
         super().__init__()
         self.__filenames = []
         self.__keys = []
-        self.__dirkeys = []
         self.loop = None
 
     def configure(self, session: Session, run: Run, **kwargs):
+        """
+                # 提交任务到事件循环
+        future = asyncio.run_coroutine_threadsafe(async_task(), loop)
+        # 阻塞等待结果
+        result = future.result()
+        """
+        """
+        置当前 Retrieval 涉及文件信息
+        """
         # 获取当前事件循环
-        # document_id = []
+        document_id = []
         file_key = []
-        # filesinfo = []
+        filesinfo = []
         # 后语要从知识库里选择文件,所以在openassistant的数据库里可能不存在
         for key in run.file_ids:
             if len(key) == 36:
-                self.__keys.append(key)  # 添加文件id 作为检索
+                document_id.append(key)
             else:
-                file_key.append(
-                    key
-                )  ## assiatant的id数据,在r2r里没办法检索需要提取filekey字段
+                file_key.append(key)
 
         print(
             "document_iddocument_iddocument_iddocument_iddocument_iddocument_iddocument_iddocument_iddocument_iddocument_iddocument_iddocument_iddocument_iddocument_iddocument_iddocument_iddocument_iddocument_iddocument_iddocument_iddocument_iddocument_iddocument_iddocument_iddocument_iddocument_iddocument_id"
         )
-        # print(document_id)
+        print(document_id)
         print(file_key)
         files = []
         # 这种情况是uuid.ex 这种格式的在最早的时候存在的,后续要去掉
         if len(file_key) > 0:
             ## 获取文件信息
             files = FileService.get_file_list_by_ids(session=session, file_ids=file_key)
-            for file in files:
-                self.__keys.append(file.key)
             print(files)
         # r2r接口不提供多条件,否则上面没必要存在
-        """
         if len(document_id) > 0:
             filesinfo += FileService.list_in_files(ids=document_id, offset=0, limit=100)
             # asyncio.run(
@@ -78,11 +81,8 @@ class FileSearchTool(BaseTool):
                 self.__filenames.append(file.get("title"))
                 self.__keys.append(file.get("id"))
             print(filesinfo)
-        """
 
         # files = FileService.list_in_files(ids=run.file_ids, offset=0, limit=100)
-
-        # 读取assistant的数据,获取文件夹的id
         db_asst = AssistantService.get_assistant_sync(
             session=session, assistant_id=run.assistant_id
         )
@@ -95,10 +95,8 @@ class FileSearchTool(BaseTool):
                 .get("folder_ids")
             )
             print(asst_folder_ids)
-            # folder_fileinfo = []
+            folder_fileinfo = []
             if asst_folder_ids:
-                self.__dirkeys = asst_folder_ids
-            """
                 for fid in asst_folder_ids:
                     folder_fileinfo += FileService.list_documents(
                         id=fid, offset=0, limit=100
@@ -110,30 +108,36 @@ class FileSearchTool(BaseTool):
                 for file in folder_fileinfo:
                     self.__filenames.append(file.get("title"))
                     self.__keys.append(file.get("id"))
-            """
+
         # pre-cache data to prevent thread conflicts that may occur later on.
         print(
             "---------ssssssssssss-----------------sssssssssssss---------------ssssssssssssss-------------sssssssssssss-------------ss-------"
         )
-        print(self.__dirkeys)
-        """
+        print(files)
         for file in files:
             self.__filenames.append(file.filename)
             self.__keys.append(file.key)
-        """
         print(self.__keys)
 
-    # indexes: List[int],
-    def run(self, query: str) -> dict:
+    def run(self, indexes: List[int], query: str) -> dict:
+
+        file_keys = []
+
+        for index in indexes:
+            if index is not None:
+                file_key = self.__keys[index]
+                file_keys.append(file_key)
         print(
             "file_keysfile_keysfile_keysfile_keysfile_keysfile_keysfile_keysfile_keysfile_keysfile_keysfile_keysfile_keysfile_keysfile_keysfile_keysfile_keysfile_keysfile_keysfile_keysfile_keysfile_keysfile_keysfile_keysfile_keysfile_keysfile_keysfile_keysfile_keysfile_keys"
         )
-
-        print(self.__keys)
-        print(self.__dirkeys)
-        files = FileService.search_in_files(
-            query=query, file_keys=self.__keys, folder_keys=self.__dirkeys
-        )
+        print(file_keys)
+        files = []
+        if len(file_keys) > 0:
+            # self.loop = asyncio.get_event_loop()
+            files = FileService.search_in_files(query=query, file_keys=file_keys)
+            # files = asyncio.run(
+            #    FileService.search_in_files(query=query, file_keys=file_keys)
+            # )
         print(files)
         return files
 
@@ -141,26 +145,16 @@ class FileSearchTool(BaseTool):
         """
         为 Retrieval 提供文件选择信息,用于 llm 调用抉择
         """
-        if (self.__keys and len(self.__keys) > 0) or (
-            self.__dirkeys and len(self.__dirkeys) > 0
-        ):
+        if len(self.__filenames) == 0:
+            return ""
+        else:
+            filenames_info = [
+                f"({index}){filename}"
+                for index, filename in enumerate(self.__filenames)
+            ]
             return (
-                "## 工具使用规范"
-                + "可调用工具:"
-                + "- retrieval:用于在文件库中搜索与问题相关的具体内容"
-                + "**调用规则**:"
-                + "1. 当问题涉及以下情况时必须调用本工具:"
-                + "   - 询问文件/文档中的具体内容"
-                + "   - 需要查找数据、条款或技术细节"
-                + '   - 用户明确要求"查文件"或"搜索资料"'
-                + "2. 调用时需遵循:"
-                + "   ```json"
-                + "   {"
-                + '     "action": "retrieval",'
-                + '     "action_input": {'
-                + '       "query": "精炼后的搜索语句,需包含至少2个关键要素(用户问题的原始上下文)"'
-                + "     }"
-                + "   }"
+                'You can use the "retrieval" tool to retrieve relevant context from the following attached files. '
+                + 'Each line represents a file in the format "(index)filename":\n'
+                + "\n".join(filenames_info)
+                + "\nMake sure to be extremely concise when using attached files. "
             )
-        else:
-            return ""