jack 4 miesięcy temu
rodzic
commit
182c7e4d3f

+ 13 - 12
app/core/tools/file_search_tool.py

@@ -14,6 +14,11 @@ import nest_asyncio
 nest_asyncio.apply()
 
 # asyncio.set_event_loop_policy(asyncio.WindowsSelectorEventLoopPolicy())
+try:
+    loop = asyncio.get_running_loop()  # 检查是否有运行的事件循环
+except RuntimeError:
+    loop = asyncio.get_event_loop()
+    print("事件循环未运行,手动启动")
 
 
 class FileSearchToolInput(BaseModel):
@@ -48,7 +53,8 @@ class FileSearchTool(BaseTool):
         """
         置当前 Retrieval 涉及文件信息
         """
-        self.loop = asyncio.get_event_loop()  # 获取当前事件循环
+
+        # 获取当前事件循环
         document_id = []
         file_key = []
         filesinfo = []
@@ -72,11 +78,9 @@ class FileSearchTool(BaseTool):
             print(files)
         # r2r接口不提供多条件,否则上面没必要存在
         if len(document_id) > 0:
-            future = asyncio.run_coroutine_threadsafe(
-                FileService.list_in_files(ids=document_id, offset=0, limit=100),
-                self.loop,
+            filesinfo += loop.run_until_complete(
+                FileService.list_in_files(ids=document_id, offset=0, limit=100)
             )
-            filesinfo += future.result()
             # asyncio.run(
             #    FileService.list_in_files(ids=document_id, offset=0, limit=100)
             # )
@@ -101,11 +105,9 @@ class FileSearchTool(BaseTool):
             folder_fileinfo = []
             if asst_folder_ids:
                 for fid in asst_folder_ids:
-                    future = asyncio.run_coroutine_threadsafe(
-                        FileService.list_documents(id=fid, offset=0, limit=100),
-                        self.loop,
+                    folder_fileinfo += loop.run_until_complete(
+                        FileService.list_documents(id=fid, offset=0, limit=100)
                     )
-                    folder_fileinfo += future.result()
                     # folder_fileinfo += asyncio.run(
                     #    FileService.list_documents(id=fid, offset=0, limit=100)
                     # )
@@ -137,10 +139,9 @@ class FileSearchTool(BaseTool):
         files = []
         if len(file_keys) > 0:
             # self.loop = asyncio.get_event_loop()
-            future = asyncio.run_coroutine_threadsafe(
-                FileService.search_in_files(query=query, file_keys=file_keys), self.loop
+            files += loop.run_until_complete(
+                FileService.search_in_files(query=query, file_keys=file_keys)
             )
-            files += future.result()
             # files = asyncio.run(
             #    FileService.search_in_files(query=query, file_keys=file_keys)
             # )

+ 4 - 4
app/services/file/impl/r2r_file.py

@@ -51,7 +51,7 @@ class R2RFileService(OSSFileService):
                     await f.write(content)
 
             # storage.save_from_path(filename=file_key, local_file_path=tmp_file_path)
-            await r2r.init()
+            # await r2r.init()
             fileinfo = await r2r.ingest_file(
                 file_path=tmp_file_path,
                 metadata={"file_key": file_key, "title": file.filename},
@@ -117,7 +117,7 @@ class R2RFileService(OSSFileService):
             r2r.search(query, filters={"file_key": {"$in": file_keys}})
         )
         """
-        await r2r.init()
+        # await r2r.init()
         search_results = await r2r.search(query, filters=filters)
         if not search_results:
             return files
@@ -147,7 +147,7 @@ class R2RFileService(OSSFileService):
         asyncio.run(r2r.init())
         list_results = asyncio.run(r2r.list(ids=ids, offset=offset, limit=limit))
         """
-        await r2r.init()
+        # await r2r.init()
         list_results = await r2r.list(ids=ids, offset=offset, limit=limit)
         return list_results
 
@@ -168,7 +168,7 @@ class R2RFileService(OSSFileService):
             r2r.list_documents(id=id, offset=offset, limit=limit)
         )
         """
-        await r2r.init()
+        # await r2r.init()
         list_results = await r2r.list_documents(id=id, offset=offset, limit=limit)
         return list_results