jack 1 kuukausi sitten
vanhempi
commit
1519a0593d

+ 138 - 6
app/core/tools/file_search_tool.py

@@ -16,12 +16,17 @@ nest_asyncio.apply()
 
 # asyncio.set_event_loop_policy(asyncio.WindowsSelectorEventLoopPolicy())
 
-
+"""
 class FileSearchToolInput(BaseModel):
     indexes: List[int] = Field(
         ..., description="file index list to look up in retrieval"
     )
     query: str = Field(..., description="query to look up in retrieval")
+"""
+
+
+class FileSearchToolInput(BaseModel):
+    query: str = Field(..., description="query to look up in retrieval")
 
 
 class FileSearchTool(BaseTool):
@@ -37,18 +42,142 @@ class FileSearchTool(BaseTool):
         super().__init__()
         self.__filenames = []
         self.__keys = []
+        self.__dirkeys = []
         self.loop = None
 
     def configure(self, session: Session, run: Run, **kwargs):
+        # 获取当前事件循环
+        # document_id = []
+        file_key = []
+        # filesinfo = []
+        # 后语要从知识库里选择文件,所以在openassistant的数据库里可能不存在
+        for key in run.file_ids:
+            if len(key) == 36:
+                self.__keys.append(key)  # 添加文件id 作为检索
+            else:
+                file_key.append(
+                    key
+                )  ## assiatant的id数据,在r2r里没办法检索需要提取filekey字段
+
+        print(
+            "document_iddocument_iddocument_iddocument_iddocument_iddocument_iddocument_iddocument_iddocument_iddocument_iddocument_iddocument_iddocument_iddocument_iddocument_iddocument_iddocument_iddocument_iddocument_iddocument_iddocument_iddocument_iddocument_iddocument_iddocument_iddocument_iddocument_id"
+        )
+        # print(document_id)
+        print(file_key)
+        files = []
+        # 这种情况是uuid.ex 这种格式的在最早的时候存在的,后续要去掉
+        if len(file_key) > 0:
+            ## 获取文件信息
+            files = FileService.get_file_list_by_ids(session=session, file_ids=file_key)
+            for file in files:
+                self.__keys.append(file.key)
+            print(files)
+        # r2r接口不提供多条件,否则上面没必要存在
+        """
+        if len(document_id) > 0:
+            filesinfo += FileService.list_in_files(ids=document_id, offset=0, limit=100)
+            # asyncio.run(
+            #    FileService.list_in_files(ids=document_id, offset=0, limit=100)
+            # )
+            for file in filesinfo:
+                self.__filenames.append(file.get("title"))
+                self.__keys.append(file.get("id"))
+            print(filesinfo)
+        """
+
+        # files = FileService.list_in_files(ids=run.file_ids, offset=0, limit=100)
+
+        # 读取assistant的数据,获取文件夹的id
+        db_asst = AssistantService.get_assistant_sync(
+            session=session, assistant_id=run.assistant_id
+        )
+
+        if db_asst.tool_resources and "file_search" in db_asst.tool_resources:
+            ##{"file_search": {"vector_store_ids": [{"file_ids": []}]}}
+            asst_folder_ids = (
+                db_asst.tool_resources.get("file_search")
+                .get("vector_stores")[0]
+                .get("folder_ids")
+            )
+            print(asst_folder_ids)
+            # folder_fileinfo = []
+            if asst_folder_ids:
+                self.__dirkeys = asst_folder_ids
+            """
+                for fid in asst_folder_ids:
+                    folder_fileinfo += FileService.list_documents(
+                        id=fid, offset=0, limit=100
+                    )
+                    # folder_fileinfo += asyncio.run(
+                    #    FileService.list_documents(id=fid, offset=0, limit=100)
+                    # )
+                print(folder_fileinfo)
+                for file in folder_fileinfo:
+                    self.__filenames.append(file.get("title"))
+                    self.__keys.append(file.get("id"))
+            """
+        # pre-cache data to prevent thread conflicts that may occur later on.
+        print(
+            "---------ssssssssssss-----------------sssssssssssss---------------ssssssssssssss-------------sssssssssssss-------------ss-------"
+        )
+        print(self.__dirkeys)
         """
-                # 提交任务到事件循环
-        future = asyncio.run_coroutine_threadsafe(async_task(), loop)
-        # 阻塞等待结果
-        result = future.result()
+        for file in files:
+            self.__filenames.append(file.filename)
+            self.__keys.append(file.key)
         """
+        print(self.__keys)
+
+    # indexes: List[int],
+    def run(self, query: str) -> dict:
+
+        print(
+            "file_keysfile_keysfile_keysfile_keysfile_keysfile_keysfile_keysfile_keysfile_keysfile_keysfile_keysfile_keysfile_keysfile_keysfile_keysfile_keysfile_keysfile_keysfile_keysfile_keysfile_keysfile_keysfile_keysfile_keysfile_keysfile_keysfile_keysfile_keysfile_keys"
+        )
+
+        print(self.__keys)
+        print(self.__dirkeys)
+        files = FileService.search_in_files(
+            query=query, file_keys=self.__keys, folder_keys=self.__dirkeys
+        )
+        print(files)
+        return files
+
         """
-        置当前 Retrieval 涉及文件信息
+        file_keys = []
+        for index in indexes:
+            if index is not None:
+                file_key = self.__keys[index]
+                file_keys.append(file_key)
+        
+
+        
+        print(file_keys)
+        files = []
+        if len(file_keys) > 0:
+            # self.loop = asyncio.get_event_loop()
+            
+            # files = asyncio.run(
+            #    FileService.search_in_files(query=query, file_keys=file_keys)
+            # )
+        print(files)
+        return files
+        """
+
+    def instruction_supplement(self) -> str:
+        """
+        为 Retrieval 提供文件选择信息,用于 llm 调用抉择
         """
+        return (
+            'You can use the "retrieval" tool to retrieve relevant context from the following attached files. '
+            # + 'Each line represents a file in the format "(index)filename":\n'
+            # + "\n".join(filenames_info)
+            + "\nMake sure to be extremely concise when using attached files. "
+        )
+
+
+'''
+    def configure(self, session: Session, run: Run, **kwargs):
         # 获取当前事件循环
         document_id = []
         file_key = []
@@ -119,6 +248,8 @@ class FileSearchTool(BaseTool):
             self.__keys.append(file.key)
         print(self.__keys)
 
+
+
     def run(self, indexes: List[int], query: str) -> dict:
 
         file_keys = []
@@ -158,3 +289,4 @@ class FileSearchTool(BaseTool):
                 + "\n".join(filenames_info)
                 + "\nMake sure to be extremely concise when using attached files. "
             )
+'''

+ 3 - 1
app/services/file/impl/base.py

@@ -49,7 +49,9 @@ class BaseFileService(ABC):
 
     @staticmethod
     @abstractmethod
-    def search_in_files(*, query: str, file_keys: List[str]) -> dict:
+    def search_in_files(
+        *, query: str, file_keys: List[str], folder_keys: List[str] = None
+    ) -> dict:
         pass
 
     @staticmethod

+ 3 - 1
app/services/file/impl/oss_file.py

@@ -93,7 +93,9 @@ class OSSFileService(BaseFileService):
         return DeleteResponse(id=file_id, deleted=True)
 
     @staticmethod
-    def search_in_files(query: str, file_keys: List[str]) -> dict:
+    def search_in_files(
+        query: str, file_keys: List[str], folder_keys: List[str] = None
+    ) -> dict:
         files = {}
         for file_key in file_keys:
             file_data = storage.load(file_key)