|
@@ -1,5 +1,5 @@
|
|
|
from typing import Optional, Any
|
|
|
-from r2r import R2RClient
|
|
|
+from r2r import R2RAsyncClient
|
|
|
|
|
|
from app.libs.util import verify_jwt_expiration
|
|
|
from config.llm import tool_settings
|
|
@@ -12,7 +12,7 @@ import asyncio
|
|
|
|
|
|
|
|
|
class R2R:
|
|
|
- client: R2RClient
|
|
|
+ client: R2RAsyncClient
|
|
|
|
|
|
def __init__(self):
|
|
|
self.auth_enabled = tool_settings.R2R_USERNAME and tool_settings.R2R_PASSWORD
|
|
@@ -22,25 +22,31 @@ class R2R:
|
|
|
if not self.auth_enabled:
|
|
|
return
|
|
|
if not self.client:
|
|
|
- self.client = R2RClient(tool_settings.R2R_BASE_URL)
|
|
|
- self.client.users.login(tool_settings.R2R_USERNAME, tool_settings.R2R_PASSWORD)
|
|
|
+ self.client = R2RAsyncClient(tool_settings.R2R_BASE_URL, "/v3")
|
|
|
+ await self.client.users.login(
|
|
|
+ tool_settings.R2R_USERNAME, tool_settings.R2R_PASSWORD
|
|
|
+ )
|
|
|
print(self.client.access_token)
|
|
|
|
|
|
def ingest_file(self, file_path: str, metadata: Optional[dict]):
|
|
|
self._check_login()
|
|
|
- ingest_response = self.client.documents.create(
|
|
|
- file_path=file_path, metadata=metadata if metadata else None, id=None
|
|
|
+ ingest_response = asyncio.run(
|
|
|
+ self.client.documents.create(
|
|
|
+ file_path=file_path, metadata=metadata if metadata else None, id=None
|
|
|
+ )
|
|
|
)
|
|
|
return ingest_response.get("results")
|
|
|
|
|
|
def search(self, query: str, filters: dict[str, Any]):
|
|
|
self._check_login()
|
|
|
- search_response = self.client.retrieval.search(
|
|
|
- query=query,
|
|
|
- search_settings={
|
|
|
- "filters": filters,
|
|
|
- "limit": tool_settings.R2R_SEARCH_LIMIT,
|
|
|
- },
|
|
|
+ search_response = asyncio.run(
|
|
|
+ self.client.retrieval.search(
|
|
|
+ query=query,
|
|
|
+ search_settings={
|
|
|
+ "filters": filters,
|
|
|
+ "limit": tool_settings.R2R_SEARCH_LIMIT,
|
|
|
+ },
|
|
|
+ )
|
|
|
)
|
|
|
return search_response.get("results").get("chunk_search_results")
|
|
|
|