jack há 4 meses atrás
pai
commit
971aa5e266
1 ficheiros alterados com 18 adições e 12 exclusões
  1. 18 12
      app/providers/r2r.py

+ 18 - 12
app/providers/r2r.py

@@ -1,5 +1,5 @@
 from typing import Optional, Any
-from r2r import R2RClient
+from r2r import R2RAsyncClient
 
 from app.libs.util import verify_jwt_expiration
 from config.llm import tool_settings
@@ -12,7 +12,7 @@ import asyncio
 
 
 class R2R:
-    client: R2RClient
+    client: R2RAsyncClient
 
     def __init__(self):
         self.auth_enabled = tool_settings.R2R_USERNAME and tool_settings.R2R_PASSWORD
@@ -22,25 +22,31 @@ class R2R:
         if not self.auth_enabled:
             return
         if not self.client:
-            self.client = R2RClient(tool_settings.R2R_BASE_URL)
-        self.client.users.login(tool_settings.R2R_USERNAME, tool_settings.R2R_PASSWORD)
+            self.client = R2RAsyncClient(tool_settings.R2R_BASE_URL, "/v3")
+        await self.client.users.login(
+            tool_settings.R2R_USERNAME, tool_settings.R2R_PASSWORD
+        )
         print(self.client.access_token)
 
     def ingest_file(self, file_path: str, metadata: Optional[dict]):
         self._check_login()
-        ingest_response = self.client.documents.create(
-            file_path=file_path, metadata=metadata if metadata else None, id=None
+        ingest_response = asyncio.run(
+            self.client.documents.create(
+                file_path=file_path, metadata=metadata if metadata else None, id=None
+            )
         )
         return ingest_response.get("results")
 
     def search(self, query: str, filters: dict[str, Any]):
         self._check_login()
-        search_response = self.client.retrieval.search(
-            query=query,
-            search_settings={
-                "filters": filters,
-                "limit": tool_settings.R2R_SEARCH_LIMIT,
-            },
+        search_response = asyncio.run(
+            self.client.retrieval.search(
+                query=query,
+                search_settings={
+                    "filters": filters,
+                    "limit": tool_settings.R2R_SEARCH_LIMIT,
+                },
+            )
         )
         return search_response.get("results").get("chunk_search_results")