jack 4 mēneši atpakaļ
vecāks
revīzija
dc23c06917
1 mainītis faili ar 40 papildinājumiem un 12 dzēšanām
  1. 40 12
      app/providers/r2r.py

+ 40 - 12
app/providers/r2r.py

@@ -30,24 +30,52 @@ class R2R:
 
     def ingest_file(self, file_path: str, metadata: Optional[dict]):
         self._check_login()
-        ingest_response = asyncio.run(
-            self.client.documents.create(
-                file_path=file_path, metadata=metadata if metadata else None, id=None
+        loop = asyncio.get_event_loop()
+        if loop.is_running():
+            # 如果事件循环已经在运行,可以通过loop.create_task()调度任务
+            ingest_response = loop.create_task(
+                self.client.documents.create(
+                    file_path=file_path,
+                    metadata=metadata if metadata else None,
+                    id=None,
+                )
+            )
+        else:
+            # 如果没有运行中的事件循环,使用 run_until_complete 来执行
+            ingest_response = loop.run_until_complete(
+                self.client.documents.create(
+                    file_path=file_path,
+                    metadata=metadata if metadata else None,
+                    id=None,
+                )
             )
-        )
         return ingest_response.get("results")
 
     def search(self, query: str, filters: dict[str, Any]):
         self._check_login()
-        search_response = asyncio.run(
-            self.client.retrieval.search(
-                query=query,
-                search_settings={
-                    "filters": filters,
-                    "limit": tool_settings.R2R_SEARCH_LIMIT,
-                },
+        loop = asyncio.get_event_loop()
+        if loop.is_running():
+            # 如果事件循环已经在运行,可以通过loop.create_task()调度任务
+            search_response = loop.create_task(
+                self.client.retrieval.search(
+                    query=query,
+                    search_settings={
+                        "filters": filters,
+                        "limit": tool_settings.R2R_SEARCH_LIMIT,
+                    },
+                )
+            )
+        else:
+            # 如果没有运行中的事件循环,使用 run_until_complete 来执行
+            search_response = loop.run_until_complete(
+                self.client.retrieval.search(
+                    query=query,
+                    search_settings={
+                        "filters": filters,
+                        "limit": tool_settings.R2R_SEARCH_LIMIT,
+                    },
+                )
             )
-        )
         return search_response.get("results").get("chunk_search_results")
 
     def _check_login(self):