|
@@ -1,6 +1,4 @@
|
|
from typing import Optional, Any
|
|
from typing import Optional, Any
|
|
-import pytest
|
|
|
|
-
|
|
|
|
from r2r import R2RClient
|
|
from r2r import R2RClient
|
|
|
|
|
|
from app.libs.util import verify_jwt_expiration
|
|
from app.libs.util import verify_jwt_expiration
|
|
@@ -15,16 +13,20 @@ nest_asyncio.apply()
|
|
class R2R:
|
|
class R2R:
|
|
client: R2RClient
|
|
client: R2RClient
|
|
|
|
|
|
- def init(self):
|
|
|
|
- # self.client = R2RClient(tool_settings.R2R_BASE_URL)
|
|
|
|
|
|
+ def __init__(self):
|
|
self.auth_enabled = tool_settings.R2R_USERNAME and tool_settings.R2R_PASSWORD
|
|
self.auth_enabled = tool_settings.R2R_USERNAME and tool_settings.R2R_PASSWORD
|
|
self.client = None
|
|
self.client = None
|
|
- loop = asyncio.get_event_loop()
|
|
|
|
- if loop.is_running():
|
|
|
|
- return loop.create_task(self._login()) # 在现有事件循环中运行异步任务
|
|
|
|
- else:
|
|
|
|
- return asyncio.run(self._login()) # 如果没有事件循环则创建一个新的
|
|
|
|
- # loop.create_task(self._login())
|
|
|
|
|
|
+
|
|
|
|
+ def init(self):
|
|
|
|
+ if not self.auth_enabled:
|
|
|
|
+ return
|
|
|
|
+ if not self.client:
|
|
|
|
+ self.client = R2RClient(tool_settings.R2R_BASE_URL)
|
|
|
|
+ asyncio.run(
|
|
|
|
+ self.client.users.login(
|
|
|
|
+ tool_settings.R2R_USERNAME, tool_settings.R2R_PASSWORD
|
|
|
|
+ )
|
|
|
|
+ )
|
|
|
|
|
|
def ingest_file(self, file_path: str, metadata: Optional[dict]):
|
|
def ingest_file(self, file_path: str, metadata: Optional[dict]):
|
|
self._check_login()
|
|
self._check_login()
|
|
@@ -35,52 +37,23 @@ class R2R:
|
|
|
|
|
|
def search(self, query: str, filters: dict[str, Any]):
|
|
def search(self, query: str, filters: dict[str, Any]):
|
|
self._check_login()
|
|
self._check_login()
|
|
- print(
|
|
|
|
- "aaaaaaacccccccccccccccccccccccccccccccccccccccccvvvvvvvvvvvvvvvvvvvvvvvvvvvvv"
|
|
|
|
- )
|
|
|
|
- print(filters)
|
|
|
|
- print(tool_settings.R2R_SEARCH_LIMIT)
|
|
|
|
search_response = self.client.retrieval.search(
|
|
search_response = self.client.retrieval.search(
|
|
query=query,
|
|
query=query,
|
|
search_settings={
|
|
search_settings={
|
|
"filters": filters,
|
|
"filters": filters,
|
|
"limit": tool_settings.R2R_SEARCH_LIMIT,
|
|
"limit": tool_settings.R2R_SEARCH_LIMIT,
|
|
- # ,"do_hybrid_search": True,
|
|
|
|
},
|
|
},
|
|
)
|
|
)
|
|
- print(search_response)
|
|
|
|
return search_response.get("results").get("chunk_search_results")
|
|
return search_response.get("results").get("chunk_search_results")
|
|
|
|
|
|
- # @pytest.fixture(scope="session")
|
|
|
|
- async def _login(self):
|
|
|
|
- if not self.auth_enabled:
|
|
|
|
- return
|
|
|
|
- print(
|
|
|
|
- "client=>client=>client=>client=>client=>client=>client=>client=>client=>client=>client=>client=>client=>client=>client=>client=>client=>client=>"
|
|
|
|
- )
|
|
|
|
- print(self.client)
|
|
|
|
- if not self.client:
|
|
|
|
- self.client = R2RClient(tool_settings.R2R_BASE_URL)
|
|
|
|
- result = await self.client.users.login(
|
|
|
|
- tool_settings.R2R_USERNAME, tool_settings.R2R_PASSWORD
|
|
|
|
- ) # 同步调用异步函数
|
|
|
|
- # return self.client
|
|
|
|
-
|
|
|
|
def _check_login(self):
|
|
def _check_login(self):
|
|
if not self.auth_enabled:
|
|
if not self.auth_enabled:
|
|
return
|
|
return
|
|
if verify_jwt_expiration(self.client.access_token):
|
|
if verify_jwt_expiration(self.client.access_token):
|
|
return
|
|
return
|
|
else:
|
|
else:
|
|
- loop = asyncio.get_event_loop()
|
|
|
|
- if loop.is_running():
|
|
|
|
- return loop.create_task(self._login()) # 在现有事件循环中运行异步任务
|
|
|
|
- else:
|
|
|
|
- return asyncio.run(self._login()) # 如果没有事件循环则创建一个新的
|
|
|
|
- # loop.create_task(self._login())
|
|
|
|
- # self._login()
|
|
|
|
|
|
+ self.init()
|
|
|
|
|
|
|
|
|
|
|
|
+# 创建 R2R 实例
|
|
r2r = R2R()
|
|
r2r = R2R()
|
|
-
|
|
|
|
-r2r.init() # 运行异步函数
|
|
|