jack 1 ay önce
ebeveyn
işleme
33e9e4733a

+ 4 - 18
app/core/tools/file_search_tool.py

@@ -48,11 +48,6 @@ class FileSearchTool(BaseTool):
         """
         置当前 Retrieval 涉及文件信息
         """
-        try:
-            loop = asyncio.get_running_loop()  # 检查是否有运行的事件循环
-        except RuntimeError:
-            loop = asyncio.get_event_loop()
-            print("事件循环未运行,手动启动")
         # 获取当前事件循环
         document_id = []
         file_key = []
@@ -77,9 +72,7 @@ class FileSearchTool(BaseTool):
             print(files)
         # r2r接口不提供多条件,否则上面没必要存在
         if len(document_id) > 0:
-            filesinfo += loop.run_until_complete(
-                FileService.list_in_files(ids=document_id, offset=0, limit=100)
-            )
+            filesinfo += FileService.list_in_files(ids=document_id, offset=0, limit=100)
             # asyncio.run(
             #    FileService.list_in_files(ids=document_id, offset=0, limit=100)
             # )
@@ -104,8 +97,8 @@ class FileSearchTool(BaseTool):
             folder_fileinfo = []
             if asst_folder_ids:
                 for fid in asst_folder_ids:
-                    folder_fileinfo += loop.run_until_complete(
-                        FileService.list_documents(id=fid, offset=0, limit=100)
+                    folder_fileinfo += FileService.list_documents(
+                        id=fid, offset=0, limit=100
                     )
                     # folder_fileinfo += asyncio.run(
                     #    FileService.list_documents(id=fid, offset=0, limit=100)
@@ -139,15 +132,8 @@ class FileSearchTool(BaseTool):
         print(file_keys)
         files = []
         if len(file_keys) > 0:
-            try:
-                loop = asyncio.get_running_loop()  # 检查是否有运行的事件循环
-            except RuntimeError:
-                loop = asyncio.get_event_loop()
-            print("事件循环未运行,手动启动")
             # self.loop = asyncio.get_event_loop()
-            files = loop.run_until_complete(
-                FileService.search_in_files(query=query, file_keys=file_keys)
-            )
+            files = FileService.search_in_files(query=query, file_keys=file_keys)
             # files = asyncio.run(
             #    FileService.search_in_files(query=query, file_keys=file_keys)
             # )

+ 37 - 9
app/providers/r2r.py

@@ -1,5 +1,6 @@
 from typing import Optional, Any
 from r2r import R2RAsyncClient
+from r2r import R2RClient
 from fastapi import UploadFile
 from app.libs.util import verify_jwt_expiration
 from config.llm import tool_settings
@@ -11,10 +12,27 @@ nest_asyncio.apply()
 
 class R2R:
     client: R2RAsyncClient
+    client_sync: R2RClient
 
     def __init__(self):
         self.auth_enabled = tool_settings.R2R_USERNAME and tool_settings.R2R_PASSWORD
         self.client = None
+        self.client_sync = None
+
+    def init_sync(self):
+        if not self.auth_enabled:
+            return
+        if not self.client_sync:
+            self.client_sync = R2RClient(tool_settings.R2R_BASE_URL, "/v3")
+        print(
+            "1111111111111111111111111111111122222vvdgdfdf" + tool_settings.R2R_USERNAME
+        )
+        print(tool_settings.R2R_USERNAME)
+        print(tool_settings.R2R_PASSWORD)
+        self.client_sync.users.login(
+            tool_settings.R2R_USERNAME, tool_settings.R2R_PASSWORD
+        )
+        print(self.client.access_token)
 
     async def init(self):
         if not self.auth_enabled:
@@ -47,12 +65,12 @@ class R2R:
             id=None,
         )
 
-    async def search(self, query: str, filters: dict[str, Any]):
-        await self._check_login()
+    def search(self, query: str, filters: dict[str, Any]):
+        self._check_login_sync()
         print(
             "aaaaaaaaaaaaaaaaaaaaaaaaaaaasssssssssssssssssssssssssssssssssssssssssgggggggggggggggggggg"
         )
-        search_response = await self.client.retrieval.search(
+        search_response = self.client_sync.retrieval.search(
             query=query,
             search_settings={
                 "filters": filters,
@@ -63,13 +81,13 @@ class R2R:
         print(search_response.get("results"))
         return search_response.get("results").get("chunk_search_results")
 
-    async def list(
+    def list(
         self,
         ids: Optional[list[str]] = None,
         offset: Optional[int] = 0,
         limit: Optional[int] = 100,
     ):
-        await self._check_login()
+        self._check_login_sync()
         print("aaaaaaaaaaaaaaaaaaaaaaaaaaaaassssssssssssssssssssssssssssssssssss")
         print(ids)
 
@@ -80,7 +98,7 @@ class R2R:
         """
         print("listlistlistlistlistlistlistlistlistlistlistlistlistlistlistlistlist")
         if len(ids) > 0:
-            listed = await self.client.documents.list(
+            listed = self.client_sync.documents.list(
                 ids=ids, limit=limit, offset=offset
             )
             print(listed.get("results"))
@@ -88,13 +106,13 @@ class R2R:
         else:
             return []
 
-    async def list_documents(
+    def list_documents(
         self,
         id: Optional[str] = "",
         offset: Optional[int] = 0,
         limit: Optional[int] = 100,
     ):
-        await self._check_login()
+        self._check_login_sync()
 
         """
             docs = client.collections.list_documents(empty_coll_id).results
@@ -105,7 +123,7 @@ class R2R:
         )
         if id != "":
             try:
-                listed = await self.client.collections.list_documents(
+                listed = self.client_sync.collections.list_documents(
                     id=id, limit=limit, offset=offset
                 )
                 print(listed.get("results"))
@@ -128,6 +146,16 @@ class R2R:
         else:
             await self.init()
 
+    def _check_login_sync(self):
+        if not self.auth_enabled:
+            return
+        if not self.client.access_token and verify_jwt_expiration(
+            self.client.access_token
+        ):
+            return
+        else:
+            self.init_sync()
+
 
 # 创建 R2R 实例
 r2r = R2R()

+ 3 - 3
app/services/file/impl/base.py

@@ -49,12 +49,12 @@ class BaseFileService(ABC):
 
     @staticmethod
     @abstractmethod
-    async def search_in_files(*, query: str, file_keys: List[str]) -> dict:
+    def search_in_files(*, query: str, file_keys: List[str]) -> dict:
         pass
 
     @staticmethod
     @abstractmethod
-    async def list_in_files(
+    def list_in_files(
         ids: list[str] = None,
         offset: int = 0,
         limit: int = 100,
@@ -63,7 +63,7 @@ class BaseFileService(ABC):
 
     @staticmethod
     @abstractmethod
-    async def list_documents(
+    def list_documents(
         id: str = "",
         offset: int = 0,
         limit: int = 100,

+ 3 - 3
app/services/file/impl/oss_file.py

@@ -93,7 +93,7 @@ class OSSFileService(BaseFileService):
         return DeleteResponse(id=file_id, deleted=True)
 
     @staticmethod
-    async def search_in_files(query: str, file_keys: List[str]) -> dict:
+    def search_in_files(query: str, file_keys: List[str]) -> dict:
         files = {}
         for file_key in file_keys:
             file_data = storage.load(file_key)
@@ -103,7 +103,7 @@ class OSSFileService(BaseFileService):
         return files
 
     @staticmethod
-    async def list_in_files(
+    def list_in_files(
         ids: list[str] = None,
         offset: int = 0,
         limit: int = 100,
@@ -111,7 +111,7 @@ class OSSFileService(BaseFileService):
         return []
 
     @staticmethod
-    async def list_documents(
+    def list_documents(
         id: str = "",
         offset: int = 0,
         limit: int = 100,

+ 9 - 9
app/services/file/impl/r2r_file.py

@@ -70,7 +70,7 @@ class R2RFileService(OSSFileService):
         return db_file
 
     @staticmethod
-    async def search_in_files(
+    def search_in_files(
         query: str, file_keys: List[str], folder_keys: List[str] = None
     ) -> dict:
         files = {}
@@ -117,8 +117,8 @@ class R2RFileService(OSSFileService):
             r2r.search(query, filters={"file_key": {"$in": file_keys}})
         )
         """
-        await r2r.init()
-        search_results = await r2r.search(query, filters=filters)
+        r2r.init_sync()
+        search_results = r2r.search(query, filters=filters)
         if not search_results:
             return files
 
@@ -133,7 +133,7 @@ class R2RFileService(OSSFileService):
         return files
 
     @staticmethod
-    async def list_in_files(
+    def list_in_files(
         ids: list[str] = None,
         offset: int = 0,
         limit: int = 100,
@@ -147,12 +147,12 @@ class R2RFileService(OSSFileService):
         asyncio.run(r2r.init())
         list_results = asyncio.run(r2r.list(ids=ids, offset=offset, limit=limit))
         """
-        await r2r.init()
-        list_results = await r2r.list(ids=ids, offset=offset, limit=limit)
+        r2r.init_sync()
+        list_results = r2r.list(ids=ids, offset=offset, limit=limit)
         return list_results
 
     @staticmethod
-    async def list_documents(
+    def list_documents(
         id: str = "",
         offset: int = 0,
         limit: int = 100,
@@ -168,8 +168,8 @@ class R2RFileService(OSSFileService):
             r2r.list_documents(id=id, offset=offset, limit=limit)
         )
         """
-        await r2r.init()
-        list_results = await r2r.list_documents(id=id, offset=offset, limit=limit)
+        r2r.init_sync()
+        list_results = r2r.list_documents(id=id, offset=offset, limit=limit)
         return list_results
 
     # TODO 删除s3&r2r文件