|
@@ -1,5 +1,6 @@
|
|
|
from typing import Optional, Any
|
|
|
from r2r import R2RAsyncClient
|
|
|
+from r2r import R2RClient
|
|
|
from fastapi import UploadFile
|
|
|
from app.libs.util import verify_jwt_expiration
|
|
|
from config.llm import tool_settings
|
|
@@ -11,10 +12,27 @@ nest_asyncio.apply()
|
|
|
|
|
|
class R2R:
|
|
|
client: R2RAsyncClient
|
|
|
+ client_sync: R2RClient
|
|
|
|
|
|
def __init__(self):
|
|
|
self.auth_enabled = tool_settings.R2R_USERNAME and tool_settings.R2R_PASSWORD
|
|
|
self.client = None
|
|
|
+ self.client_sync = None
|
|
|
+
|
|
|
+ def init_sync(self):
|
|
|
+ if not self.auth_enabled:
|
|
|
+ return
|
|
|
+ if not self.client_sync:
|
|
|
+ self.client_sync = R2RClient(tool_settings.R2R_BASE_URL, "/v3")
|
|
|
+ print(
|
|
|
+ "1111111111111111111111111111111122222vvdgdfdf" + tool_settings.R2R_USERNAME
|
|
|
+ )
|
|
|
+ print(tool_settings.R2R_USERNAME)
|
|
|
+ print(tool_settings.R2R_PASSWORD)
|
|
|
+ self.client_sync.users.login(
|
|
|
+ tool_settings.R2R_USERNAME, tool_settings.R2R_PASSWORD
|
|
|
+ )
|
|
|
+ print(self.client.access_token)
|
|
|
|
|
|
async def init(self):
|
|
|
if not self.auth_enabled:
|
|
@@ -47,12 +65,12 @@ class R2R:
|
|
|
id=None,
|
|
|
)
|
|
|
|
|
|
- async def search(self, query: str, filters: dict[str, Any]):
|
|
|
- await self._check_login()
|
|
|
+ def search(self, query: str, filters: dict[str, Any]):
|
|
|
+ self._check_login_sync()
|
|
|
print(
|
|
|
"aaaaaaaaaaaaaaaaaaaaaaaaaaaasssssssssssssssssssssssssssssssssssssssssgggggggggggggggggggg"
|
|
|
)
|
|
|
- search_response = await self.client.retrieval.search(
|
|
|
+ search_response = self.client_sync.retrieval.search(
|
|
|
query=query,
|
|
|
search_settings={
|
|
|
"filters": filters,
|
|
@@ -63,13 +81,13 @@ class R2R:
|
|
|
print(search_response.get("results"))
|
|
|
return search_response.get("results").get("chunk_search_results")
|
|
|
|
|
|
- async def list(
|
|
|
+ def list(
|
|
|
self,
|
|
|
ids: Optional[list[str]] = None,
|
|
|
offset: Optional[int] = 0,
|
|
|
limit: Optional[int] = 100,
|
|
|
):
|
|
|
- await self._check_login()
|
|
|
+ self._check_login_sync()
|
|
|
print("aaaaaaaaaaaaaaaaaaaaaaaaaaaaassssssssssssssssssssssssssssssssssss")
|
|
|
print(ids)
|
|
|
|
|
@@ -80,7 +98,7 @@ class R2R:
|
|
|
"""
|
|
|
print("listlistlistlistlistlistlistlistlistlistlistlistlistlistlistlistlist")
|
|
|
if len(ids) > 0:
|
|
|
- listed = await self.client.documents.list(
|
|
|
+ listed = self.client_sync.documents.list(
|
|
|
ids=ids, limit=limit, offset=offset
|
|
|
)
|
|
|
print(listed.get("results"))
|
|
@@ -88,13 +106,13 @@ class R2R:
|
|
|
else:
|
|
|
return []
|
|
|
|
|
|
- async def list_documents(
|
|
|
+ def list_documents(
|
|
|
self,
|
|
|
id: Optional[str] = "",
|
|
|
offset: Optional[int] = 0,
|
|
|
limit: Optional[int] = 100,
|
|
|
):
|
|
|
- await self._check_login()
|
|
|
+ self._check_login_sync()
|
|
|
|
|
|
"""
|
|
|
docs = client.collections.list_documents(empty_coll_id).results
|
|
@@ -105,7 +123,7 @@ class R2R:
|
|
|
)
|
|
|
if id != "":
|
|
|
try:
|
|
|
- listed = await self.client.collections.list_documents(
|
|
|
+ listed = self.client_sync.collections.list_documents(
|
|
|
id=id, limit=limit, offset=offset
|
|
|
)
|
|
|
print(listed.get("results"))
|
|
@@ -128,6 +146,16 @@ class R2R:
|
|
|
else:
|
|
|
await self.init()
|
|
|
|
|
|
+ def _check_login_sync(self):
|
|
|
+ if not self.auth_enabled:
|
|
|
+ return
|
|
|
+ if not self.client.access_token and verify_jwt_expiration(
|
|
|
+ self.client.access_token
|
|
|
+ ):
|
|
|
+ return
|
|
|
+ else:
|
|
|
+ self.init_sync()
|
|
|
+
|
|
|
|
|
|
# 创建 R2R 实例
|
|
|
r2r = R2R()
|