jack há 4 meses atrás
pai
commit
a8ad9a3ea8
2 ficheiros alterados com 12 adições e 8 exclusões
  1. 9 6
      app/providers/r2r.py
  2. 3 2
      app/services/file/impl/r2r_file.py

+ 9 - 6
app/providers/r2r.py

@@ -17,15 +17,13 @@ class R2R:
         self.auth_enabled = tool_settings.R2R_USERNAME and tool_settings.R2R_PASSWORD
         self.client = None
 
-    def init(self):
+    async def init(self):
         if not self.auth_enabled:
             return
         if not self.client:
             self.client = R2RClient(tool_settings.R2R_BASE_URL)
-            asyncio.run(
-                self.client.users.login(
-                    tool_settings.R2R_USERNAME, tool_settings.R2R_PASSWORD
-                )
+            await self.client.users.login(
+                tool_settings.R2R_USERNAME, tool_settings.R2R_PASSWORD
             )
 
     def ingest_file(self, file_path: str, metadata: Optional[dict]):
@@ -52,8 +50,13 @@ class R2R:
         if verify_jwt_expiration(self.client.access_token):
             return
         else:
-            self.init()
+            asyncio.create_task(self.init())
 
 
 # 创建 R2R 实例
 r2r = R2R()
+
+
+# 在您的应用程序启动时调用 initialize_r2r()
+async def initialize_r2r():
+    await r2r.init()

+ 3 - 2
app/services/file/impl/r2r_file.py

@@ -11,6 +11,7 @@ from app.models import File
 from app.providers.r2r import r2r
 from app.providers.storage import storage
 from app.services.file.impl.oss_file import OSSFileService
+import asyncio
 
 
 class R2RFileService(OSSFileService):
@@ -42,7 +43,7 @@ class R2RFileService(OSSFileService):
                     await f.write(content)
 
             storage.save_from_path(filename=file_key, local_file_path=tmp_file_path)
-            r2r.init()
+            await r2r.init()
             r2r.ingest_file(file_path=tmp_file_path, metadata={"file_key": file_key})
 
         # 存储
@@ -57,7 +58,7 @@ class R2RFileService(OSSFileService):
     @staticmethod
     def search_in_files(query: str, file_keys: List[str]) -> dict:
         files = {}
-        r2r.init()
+        asyncio.create_task(r2r.init())
         search_results = r2r.search(query, filters={"file_key": {"$in": file_keys}})
         if not search_results:
             return files