jack 1 month ago
parent
commit
a54c305b4c

+ 7 - 3
app/core/tools/file_search_tool.py

@@ -30,11 +30,13 @@ class FileSearchTool(BaseTool):
         super().__init__()
         self.__filenames = []
         self.__keys = []
+        self.loop = None
 
     def configure(self, session: Session, run: Run, **kwargs):
         """
         置当前 Retrieval 涉及文件信息
         """
+        self.loop = asyncio.get_event_loop()  # 获取当前事件循环
         ## 获取文件信息
         files = FileService.get_file_list_by_ids(session=session, file_ids=run.file_ids)
         # files = FileService.list_in_files(ids=run.file_ids, offset=0, limit=100)
@@ -53,8 +55,8 @@ class FileSearchTool(BaseTool):
             folder_fileinfo = []
             if asst_folder_ids:
                 for fid in asst_folder_ids:
-                    folder_fileinfo += FileService.list_documents(
-                        id=fid, offset=0, limit=100
+                    folder_fileinfo += self.loop.run_until_complete(
+                        FileService.list_documents(id=fid, offset=0, limit=100)
                     )
                 print(folder_fileinfo)
                 for file in folder_fileinfo:
@@ -78,7 +80,9 @@ class FileSearchTool(BaseTool):
             file_key = self.__keys[index]
             file_keys.append(file_key)
 
-        files = FileService.search_in_files(query=query, file_keys=file_keys)
+        files = self.loop.run_until_complete(
+            FileService.search_in_files(query=query, file_keys=file_keys)
+        )
         return files
 
     def instruction_supplement(self) -> str:

+ 3 - 3
app/services/file/impl/base.py

@@ -49,12 +49,12 @@ class BaseFileService(ABC):
 
     @staticmethod
     @abstractmethod
-    def search_in_files(*, query: str, file_keys: List[str]) -> dict:
+    async def search_in_files(*, query: str, file_keys: List[str]) -> dict:
         pass
 
     @staticmethod
     @abstractmethod
-    def list_in_files(
+    async def list_in_files(
         ids: list[str] = None,
         offset: int = 0,
         limit: int = 100,
@@ -63,7 +63,7 @@ class BaseFileService(ABC):
 
     @staticmethod
     @abstractmethod
-    def list_documents(
+    async def list_documents(
         id: str = "",
         offset: int = 0,
         limit: int = 100,

+ 3 - 3
app/services/file/impl/oss_file.py

@@ -93,7 +93,7 @@ class OSSFileService(BaseFileService):
         return DeleteResponse(id=file_id, deleted=True)
 
     @staticmethod
-    def search_in_files(query: str, file_keys: List[str]) -> dict:
+    async def search_in_files(query: str, file_keys: List[str]) -> dict:
         files = {}
         for file_key in file_keys:
             file_data = storage.load(file_key)
@@ -103,7 +103,7 @@ class OSSFileService(BaseFileService):
         return files
 
     @staticmethod
-    def list_in_files(
+    async def list_in_files(
         ids: list[str] = None,
         offset: int = 0,
         limit: int = 100,
@@ -111,7 +111,7 @@ class OSSFileService(BaseFileService):
         return []
 
     @staticmethod
-    def list_documents(
+    async def list_documents(
         id: str = "",
         offset: int = 0,
         limit: int = 100,

+ 12 - 6
app/services/file/impl/r2r_file.py

@@ -65,7 +65,7 @@ class R2RFileService(OSSFileService):
         return db_file
 
     @staticmethod
-    def search_in_files(
+    async def search_in_files(
         query: str, file_keys: List[str], folder_keys: List[str] = None
     ) -> dict:
         files = {}
@@ -88,9 +88,11 @@ class R2RFileService(OSSFileService):
         loop = asyncio.get_event_loop()  # 获取当前事件循环
         loop.run_until_complete(r2r.init())  # 确保 r2r 已初始化
         search_results = loop.run_until_complete(r2r.search(query, filters=filters))
-        """
         asyncio.run(r2r.init())
         search_results = asyncio.run(r2r.search(query, filters=filters))
+        """
+        await r2r.init()
+        search_results = await r2r.search(query, filters=filters)
         if not search_results:
             return files
 
@@ -105,7 +107,7 @@ class R2RFileService(OSSFileService):
         return files
 
     @staticmethod
-    def list_in_files(
+    async def list_in_files(
         ids: list[str] = None,
         offset: int = 0,
         limit: int = 100,
@@ -116,13 +118,15 @@ class R2RFileService(OSSFileService):
         list_results = loop.run_until_complete(
             r2r.list(ids=ids, offset=offset, limit=limit)
         )
-        """
         asyncio.run(r2r.init())
         list_results = asyncio.run(r2r.list(ids=ids, offset=offset, limit=limit))
+        """
+        await r2r.init()
+        list_results = await r2r.list(ids=ids, offset=offset, limit=limit)
         return list_results
 
     @staticmethod
-    def list_documents(
+    async def list_documents(
         id: str = "",
         offset: int = 0,
         limit: int = 100,
@@ -133,11 +137,13 @@ class R2RFileService(OSSFileService):
         list_results = loop.run_until_complete(
             r2r.list_documents(id=id, offset=offset, limit=limit)
         )
-        """
         asyncio.run(r2r.init())
         list_results = asyncio.run(
             r2r.list_documents(id=id, offset=offset, limit=limit)
         )
+        """
+        await r2r.init()
+        list_results = await r2r.list_documents(id=id, offset=offset, limit=limit)
         return list_results
 
     # TODO 删除s3&r2r文件