|
@@ -4,11 +4,10 @@ from r2r import R2RAsyncClient
|
|
from app.libs.util import verify_jwt_expiration
|
|
from app.libs.util import verify_jwt_expiration
|
|
from config.llm import tool_settings
|
|
from config.llm import tool_settings
|
|
|
|
|
|
-# import nest_asyncio
|
|
|
|
-import asyncio
|
|
|
|
|
|
+import nest_asyncio
|
|
|
|
|
|
-# Apply nest_asyncio to allow nested event loops
|
|
|
|
-# nest_asyncio.apply()
|
|
|
|
|
|
+# 使得异步代码可以在已运行的事件循环中嵌套
|
|
|
|
+nest_asyncio.apply()
|
|
|
|
|
|
|
|
|
|
class R2R:
|
|
class R2R:
|
|
@@ -28,57 +27,25 @@ class R2R:
|
|
)
|
|
)
|
|
print(self.client.access_token)
|
|
print(self.client.access_token)
|
|
|
|
|
|
- def ingest_file(self, file_path: str, metadata: Optional[dict]):
|
|
|
|
- self._check_login()
|
|
|
|
- loop = asyncio.get_event_loop()
|
|
|
|
- if loop.is_running():
|
|
|
|
- # 如果事件循环已经在运行,可以通过loop.create_task()调度任务
|
|
|
|
- ingest_response = loop.create_task(
|
|
|
|
- self.client.documents.create(
|
|
|
|
- file_path=file_path,
|
|
|
|
- metadata=metadata if metadata else None,
|
|
|
|
- id=None,
|
|
|
|
- )
|
|
|
|
- )
|
|
|
|
- else:
|
|
|
|
- # 如果没有运行中的事件循环,使用 run_until_complete 来执行
|
|
|
|
- ingest_response = loop.run_until_complete(
|
|
|
|
- self.client.documents.create(
|
|
|
|
- file_path=file_path,
|
|
|
|
- metadata=metadata if metadata else None,
|
|
|
|
- id=None,
|
|
|
|
- )
|
|
|
|
- )
|
|
|
|
- return ingest_response.get("results")
|
|
|
|
|
|
+ async def ingest_file(self, file_path: str, metadata: Optional[dict]):
|
|
|
|
+ await self._check_login()
|
|
|
|
+ return await self.client.documents.create(
|
|
|
|
+ file_path=file_path,
|
|
|
|
+ metadata=metadata if metadata else None,
|
|
|
|
+ id=None,
|
|
|
|
+ )
|
|
|
|
|
|
- def search(self, query: str, filters: dict[str, Any]):
|
|
|
|
- self._check_login()
|
|
|
|
- loop = asyncio.get_event_loop()
|
|
|
|
- if loop.is_running():
|
|
|
|
- # 如果事件循环已经在运行,可以通过loop.create_task()调度任务
|
|
|
|
- search_response = loop.create_task(
|
|
|
|
- self.client.retrieval.search(
|
|
|
|
- query=query,
|
|
|
|
- search_settings={
|
|
|
|
- "filters": filters,
|
|
|
|
- "limit": tool_settings.R2R_SEARCH_LIMIT,
|
|
|
|
- },
|
|
|
|
- )
|
|
|
|
- )
|
|
|
|
- else:
|
|
|
|
- # 如果没有运行中的事件循环,使用 run_until_complete 来执行
|
|
|
|
- search_response = loop.run_until_complete(
|
|
|
|
- self.client.retrieval.search(
|
|
|
|
- query=query,
|
|
|
|
- search_settings={
|
|
|
|
- "filters": filters,
|
|
|
|
- "limit": tool_settings.R2R_SEARCH_LIMIT,
|
|
|
|
- },
|
|
|
|
- )
|
|
|
|
- )
|
|
|
|
- return search_response.get("results").get("chunk_search_results")
|
|
|
|
|
|
+ async def search(self, query: str, filters: dict[str, Any]):
|
|
|
|
+ await self._check_login()
|
|
|
|
+ return await self.client.retrieval.search(
|
|
|
|
+ query=query,
|
|
|
|
+ search_settings={
|
|
|
|
+ "filters": filters,
|
|
|
|
+ "limit": tool_settings.R2R_SEARCH_LIMIT,
|
|
|
|
+ },
|
|
|
|
+ )
|
|
|
|
|
|
- def _check_login(self):
|
|
|
|
|
|
+ async def _check_login(self):
|
|
if not self.auth_enabled:
|
|
if not self.auth_enabled:
|
|
return
|
|
return
|
|
if not self.client.access_token and verify_jwt_expiration(
|
|
if not self.client.access_token and verify_jwt_expiration(
|
|
@@ -86,7 +53,7 @@ class R2R:
|
|
):
|
|
):
|
|
return
|
|
return
|
|
else:
|
|
else:
|
|
- asyncio.create_task(self.init())
|
|
|
|
|
|
+ await self.init()
|
|
|
|
|
|
|
|
|
|
# 创建 R2R 实例
|
|
# 创建 R2R 实例
|