jack 4 months ago
parent
commit
424a9c330e
2 changed files with 27 additions and 56 deletions
  1. 21 54
      app/providers/r2r.py
  2. 6 2
      app/services/file/impl/r2r_file.py

+ 21 - 54
app/providers/r2r.py

@@ -4,11 +4,10 @@ from r2r import R2RAsyncClient
 from app.libs.util import verify_jwt_expiration
 from config.llm import tool_settings
 
-# import nest_asyncio
-import asyncio
+import nest_asyncio
 
-# Apply nest_asyncio to allow nested event loops
-# nest_asyncio.apply()
+# 使得异步代码可以在已运行的事件循环中嵌套
+nest_asyncio.apply()
 
 
 class R2R:
@@ -28,57 +27,25 @@ class R2R:
         )
         print(self.client.access_token)
 
-    def ingest_file(self, file_path: str, metadata: Optional[dict]):
-        self._check_login()
-        loop = asyncio.get_event_loop()
-        if loop.is_running():
-            # 如果事件循环已经在运行,可以通过loop.create_task()调度任务
-            ingest_response = loop.create_task(
-                self.client.documents.create(
-                    file_path=file_path,
-                    metadata=metadata if metadata else None,
-                    id=None,
-                )
-            )
-        else:
-            # 如果没有运行中的事件循环,使用 run_until_complete 来执行
-            ingest_response = loop.run_until_complete(
-                self.client.documents.create(
-                    file_path=file_path,
-                    metadata=metadata if metadata else None,
-                    id=None,
-                )
-            )
-        return ingest_response.get("results")
+    async def ingest_file(self, file_path: str, metadata: Optional[dict]):
+        await self._check_login()
+        return await self.client.documents.create(
+            file_path=file_path,
+            metadata=metadata if metadata else None,
+            id=None,
+        )
 
-    def search(self, query: str, filters: dict[str, Any]):
-        self._check_login()
-        loop = asyncio.get_event_loop()
-        if loop.is_running():
-            # 如果事件循环已经在运行,可以通过loop.create_task()调度任务
-            search_response = loop.create_task(
-                self.client.retrieval.search(
-                    query=query,
-                    search_settings={
-                        "filters": filters,
-                        "limit": tool_settings.R2R_SEARCH_LIMIT,
-                    },
-                )
-            )
-        else:
-            # 如果没有运行中的事件循环,使用 run_until_complete 来执行
-            search_response = loop.run_until_complete(
-                self.client.retrieval.search(
-                    query=query,
-                    search_settings={
-                        "filters": filters,
-                        "limit": tool_settings.R2R_SEARCH_LIMIT,
-                    },
-                )
-            )
-        return search_response.get("results").get("chunk_search_results")
+    async def search(self, query: str, filters: dict[str, Any]):
+        await self._check_login()
+        return await self.client.retrieval.search(
+            query=query,
+            search_settings={
+                "filters": filters,
+                "limit": tool_settings.R2R_SEARCH_LIMIT,
+            },
+        )
 
-    def _check_login(self):
+    async def _check_login(self):
         if not self.auth_enabled:
             return
         if not self.client.access_token and verify_jwt_expiration(
@@ -86,7 +53,7 @@ class R2R:
         ):
             return
         else:
-            asyncio.create_task(self.init())
+            await self.init()
 
 
 # 创建 R2R 实例

+ 6 - 2
app/services/file/impl/r2r_file.py

@@ -44,7 +44,9 @@ class R2RFileService(OSSFileService):
 
             storage.save_from_path(filename=file_key, local_file_path=tmp_file_path)
             await r2r.init()
-            r2r.ingest_file(file_path=tmp_file_path, metadata={"file_key": file_key})
+            await r2r.ingest_file(
+                file_path=tmp_file_path, metadata={"file_key": file_key}
+            )
 
         # 存储
         db_file = File(
@@ -59,7 +61,9 @@ class R2RFileService(OSSFileService):
     def search_in_files(query: str, file_keys: List[str]) -> dict:
         files = {}
         asyncio.create_task(r2r.init())
-        search_results = r2r.search(query, filters={"file_key": {"$in": file_keys}})
+        search_results = asyncio.create_task(
+            r2r.search(query, filters={"file_key": {"$in": file_keys}})
+        )
         if not search_results:
             return files