jack 1 week ago
parent
commit
cefba0a222
1 changed files with 34 additions and 32 deletions
  1. 34 32
      app/providers/r2r.py

+ 34 - 32
app/providers/r2r.py

@@ -22,23 +22,24 @@ class R2R:
     def init_sync(self):
         if not self.auth_enabled:
             return
-        if not self.client_sync:
-            self.client_sync = R2RClient(tool_settings.R2R_BASE_URL, 300)
-            self.client_sync.users.login(
-                tool_settings.R2R_USERNAME, tool_settings.R2R_PASSWORD
-            )
+        # if not self.client_sync:
+        self.client_sync = R2RClient(tool_settings.R2R_BASE_URL)
+        self.client_sync.users.login(
+            tool_settings.R2R_USERNAME, tool_settings.R2R_PASSWORD
+        )
         print(
             "1111111111111111111111111111111122222vvdgdfdf" + tool_settings.R2R_USERNAME
         )
         # print(tool_settings.R2R_USERNAME)
         # print(tool_settings.R2R_PASSWORD)
         print(self.client_sync)
+        return self.client_sync
 
     async def init(self):
         if not self.auth_enabled:
             return
-        if not self.client:
-            self.client = R2RAsyncClient(tool_settings.R2R_BASE_URL, 300)
+        # if not self.client:
+        self.client = R2RAsyncClient(tool_settings.R2R_BASE_URL)
         print(
             "1111111111111111111111111111111122222vvdgdfdf" + tool_settings.R2R_USERNAME
         )
@@ -48,10 +49,11 @@ class R2R:
             tool_settings.R2R_USERNAME, tool_settings.R2R_PASSWORD
         )
         print(self.client.access_token)
+        return self.client
 
     async def ingest_file(self, file_path: str, metadata: Optional[dict]):
-        await self._check_login()
-        return await self.client.documents.create(
+        client = await self._check_login()
+        return await client.documents.create(
             file_path=file_path,
             metadata=metadata if metadata else None,
             ingestion_mode="fast",
@@ -59,19 +61,19 @@ class R2R:
         )
 
     async def ingest_fileinfo(self, file: UploadFile, metadata: Optional[dict]):
-        await self._check_login()
-        return await self.client.documents.create(
+        client = await self._check_login()
+        return await client.documents.create(
             file=file,
             metadata=metadata if metadata else None,
             id=None,
         )
 
     def search(self, query: str, filters: dict[str, Any]):
-        self._check_login_sync()
+        client = self._check_login_sync()
         print(
             "aaaaaaaaaaaaaaaaaaaaaaaaaaaasssssssssssssssssssssssssssssssssssssssssgggggggggggggggggggg"
         )
-        search_response = self.client_sync.retrieval.search(
+        search_response = client.retrieval.search(
             query=query,
             search_mode="basic",
             search_settings={
@@ -85,14 +87,14 @@ class R2R:
         return search_response.results.chunk_search_results
 
     def list_chunks(self, ids: list[str] = []):
-        self._check_login_sync()
+        client = self._check_login_sync()
         print(
             "retrieve_documentsretrieve_documentsretrieve_documentsretrieve_documentsretrieve_documents"
         )
         print(ids)
         allfile = []
         for id in ids:
-            listed = self.client_sync.documents.list_chunks(id=id)
+            listed = client.documents.list_chunks(id=id)
             allfile += listed.results
         return allfile
 
@@ -102,7 +104,7 @@ class R2R:
         offset: Optional[int] = 0,
         limit: Optional[int] = 100,
     ):
-        self._check_login_sync()
+        client = self._check_login_sync()
 
         """
             docs = client.collections.list_documents(empty_coll_id).results
@@ -113,7 +115,7 @@ class R2R:
         )
         if id != "":
             try:
-                listed = self.client_sync.collections.list_documents(
+                listed = client.collections.list_documents(
                     id=id, limit=limit, offset=offset
                 )
                 print(listed.results)
@@ -129,25 +131,25 @@ class R2R:
     async def _check_login(self):
         if not self.auth_enabled:
             return
-        if self.client.access_token and verify_jwt_expiration(self.client.access_token):
-            return
-        else:
-            await self.init()
+        # if self.client.access_token and verify_jwt_expiration(self.client.access_token):
+        #    return
+        # else:
+        return await self.init()
 
     def _check_login_sync(self):
         print("access_tokenaccess_tokenaccess_tokenaccess_token")
         print(self.client_sync)
         if not self.auth_enabled:
             return
-        try:
-            if self.client_sync.access_token and verify_jwt_expiration(
-                self.client_sync.access_token
-            ):
-                print(self.client_sync.access_token)
-                return
-        except Exception as e:
-            print(e)
-        self.init_sync()
+        # try:
+        #    if self.client_sync.access_token and verify_jwt_expiration(
+        #        self.client_sync.access_token
+        #     ):
+        #         print(self.client_sync.access_token)
+        #         return
+        # except Exception as e:
+        #     print(e)
+        return self.init_sync()
 
 
 # 创建 R2R 实例
@@ -155,5 +157,5 @@ r2r = R2R()
 
 
 # 在您的应用程序启动时调用 initialize_r2r()
-async def initialize_r2r():
-    await r2r.init()
+# async def initialize_r2r():
+#    await r2r.init()